Skip to content

Commit b8021ed

Browse files
committed
Feedback
1 parent 9457b54 commit b8021ed

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -931,14 +931,23 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
931931
return failure();
932932

933933
VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
934-
Value steps = vector::StepOp::create(rewriter, loc, newTy);
934+
auto steps = vector::StepOp::create(rewriter, loc, newTy);
935935
SmallVector<Value> newOps;
936936
for (auto offsets : *sgOffsets) {
937937
// Broadcast the offset scalar to a vector & add to the base steps
938-
Value bcastOffset =
938+
auto bcastOffset =
939939
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
940-
Value finalSteps =
940+
auto finalSteps =
941941
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
942+
if (!layout.getLaneLayoutAsInt().empty() ||
943+
!layout.getLaneDataAsInt().empty()) {
944+
xegpu::setDistributeLayoutAttr(steps->getResult(0),
945+
layout.dropSgLayoutAndData());
946+
xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
947+
layout.dropSgLayoutAndData());
948+
xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
949+
layout.dropSgLayoutAndData());
950+
}
942951
newOps.push_back(finalSteps);
943952
}
944953

@@ -971,14 +980,12 @@ struct WgToSgVectorShapeCastOp
971980
VectorType::get(sgShape, resultType.getElementType());
972981

973982
// TODO: Add check for compatible layouts in layout attr.
974-
// Only support ShapeCast which expands or reduces unit dims only.
975-
// That is, only allow shape casts where the non-unit dimensions are
976-
// preserved, and any added or removed dimensions must be of size 1.
977983
auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
978984
if (!srcType)
979985
return failure();
980986

981-
auto isUnitOrPreserved = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
987+
// Check that shape_cast only adds/removes unit dimensions,
988+
auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
982989
// Remove all 1s from both shapes and compare the rest.
983990
SmallVector<int64_t> srcNonUnit, dstNonUnit;
984991
for (int64_t d : src)
@@ -990,8 +997,8 @@ struct WgToSgVectorShapeCastOp
990997
return srcNonUnit == dstNonUnit;
991998
};
992999

993-
if (!isUnitOrPreserved(srcType.getShape(), sgShape) ||
994-
!isUnitOrPreserved(sgShape, srcType.getShape()))
1000+
if (!onlyUnitDims(srcType.getShape(), sgShape) ||
1001+
!onlyUnitDims(sgShape, srcType.getShape()))
9951002
return failure();
9961003

9971004
SmallVector<Value> newShapeCastOps;

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,4 +418,13 @@ gpu.module @test_distribution {
418418
%cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [32, 1, 32, 1]>} : vector<256x128xf32> to vector<256x1x128x1xf32>
419419
gpu.return
420420
}
421+
422+
// CHECK-LABEL: broadcast
423+
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index
424+
gpu.func @broadcast(%arg0: index, %arg1: index) {
425+
%muli = arith.muli %arg0, %arg1 : index
426+
// CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
427+
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
428+
gpu.return
429+
}
421430
}

0 commit comments

Comments
 (0)