Skip to content

Commit 015b8a3

Browse files
committed
bug fix in shape cast
1 parent 01fc929 commit 015b8a3

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
10091009
};
10101010

10111011
/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1012-
/// VectorReductionOps.
1012+
/// VectorReductionOps. We also insert layouts for the newly created ops.
10131013
static Value lowerToVectorReductions(TypedValue<VectorType> src,
10141014
TypedValue<VectorType> acc,
10151015
vector::CombiningKind kind,
@@ -1026,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
10261026
Value reductionResult = arith::ConstantOp::create(
10271027
rewriter, loc, acc.getType(),
10281028
DenseElementsAttr::get(acc.getType(), zeroAttr));
1029+
// Reduction result should have the same layout as the accumulator.
1030+
xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
1031+
xegpu::getDistributeLayoutAttr(acc));
10291032
// For each slice of the source, extract the slice vector, do a reduction
10301033
// and, insert the reduced value back to the result vector.
10311034
for (int i = 0; i < nSlices; ++i) {
@@ -1041,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
10411044
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
10421045
sliceSizes, {1, 1});
10431046
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1044-
Value slice = vector::ShapeCastOp::create(
1047+
vector::ShapeCastOp slice = vector::ShapeCastOp::create(
10451048
rewriter, loc,
10461049
VectorType::get({nSliceElements}, sourceType.getElementType()),
10471050
extractOp.getResult());
1051+
// Shape cast is currently handled in xegpu side. So layouts must be
1052+
// retained during lowering. Shape cast output has the same layout as the
1053+
// accumulator. Shape cast source has the same layout as the original
1054+
// reduction source.
1055+
// TODO: other ops generated here may also need layout attributes.
1056+
xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
1057+
xegpu::getDistributeLayoutAttr(src));
1058+
xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
1059+
xegpu::getDistributeLayoutAttr(acc));
1060+
// Extract and reduction results in scalars, so no result layout is needed.
10481061
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1049-
Value reduction =
1050-
vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
1062+
Value reduction = vector::ReductionOp::create(
1063+
rewriter, loc, kind, slice.getResult(), accExtract);
10511064
reductionResult =
10521065
vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
10531066
}
@@ -1229,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
12291242
auto resultDistTy =
12301243
cast<VectorType>(warpOp.getResult(operandNumber).getType());
12311244
xegpu::DistributeLayoutAttr sourceLayout =
1232-
xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
1245+
xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
12331246
xegpu::DistributeLayoutAttr resultLayout =
12341247
xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
12351248
if (!sourceLayout || !resultLayout)

mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,11 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
388388
// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
389389
// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
390390
// CHECK-SAME: {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
391-
// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16x1xf32> to vector<16xf32>
391+
// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
392392
// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
393393
// CHECK: %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
394394
// CHECK-SAME: {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
395-
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<16x1xf32> to vector<16xf32>
395+
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] {{.*}} : vector<16x1xf32> to vector<16xf32>
396396
// CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
397397
// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
398398
// CHECK: gpu.yield %[[T7]] : vector<2xf32>

0 commit comments

Comments (0)