Skip to content

Commit 1161e28

Browse files
committed
Feedback
1 parent 8cb5ebe commit 1161e28

File tree

2 files changed

+37
-31
lines changed

2 files changed

+37
-31
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -487,7 +487,8 @@ struct WgToSgVectorBroadcastOp
487487
for (auto operand : adaptor.getOperands().front()) {
488488
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
489489
newResultType, operand);
490-
if (!layout.getLaneLayoutAsInt().empty())
490+
if (!layout.getLaneLayoutAsInt().empty() ||
491+
!layout.getLaneDataAsInt().empty())
491492
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
492493
layout.dropSgLayoutAndData());
493494

@@ -546,7 +547,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
546547
for (auto attr : op->getAttrs()) {
547548
if (auto layout =
548549
dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
549-
if (!layout.getLaneLayoutAsInt().empty())
550+
if (!layout.getLaneLayoutAsInt().empty() ||
551+
!layout.getLaneDataAsInt().empty())
550552
state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
551553
} else {
552554
state.addAttribute(attr.getName(), attr.getValue());
@@ -738,7 +740,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
738740
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
739741
auto cstOp =
740742
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
741-
if (!layout.getLaneLayoutAsInt().empty())
743+
if (!layout.getLaneLayoutAsInt().empty() ||
744+
!layout.getLaneDataAsInt().empty())
742745
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
743746
layout.dropSgLayoutAndData());
744747
SmallVector<Value> newConsts(count, cstOp);
@@ -923,18 +926,20 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
923926

924927
Value sgId =
925928
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
926-
auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
927-
if (failed(maybeOffsets))
929+
auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
930+
if (failed(sgOffsets))
928931
return failure();
929932

930933
VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
931-
Value base = vector::StepOp::create(rewriter, loc, newTy);
934+
Value steps = vector::StepOp::create(rewriter, loc, newTy);
932935
SmallVector<Value> newOps;
933-
for (auto offsets : *maybeOffsets) {
934-
Value bcast =
936+
for (auto offsets : *sgOffsets) {
937+
// Broadcast the offset scalar to a vector & add to the base steps
938+
Value bcastOffset =
935939
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
936-
Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
937-
newOps.push_back(add);
940+
Value finalSteps =
941+
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
942+
newOps.push_back(finalSteps);
938943
}
939944

940945
rewriter.replaceOpWithMultiple(op, {newOps});
@@ -969,7 +974,8 @@ struct WgToSgVectorShapeCastOp
969974
for (auto src : adaptor.getSource()) {
970975
auto newShapeCast =
971976
rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
972-
if (!layout.getLaneLayoutAsInt().empty())
977+
if (!layout.getLaneLayoutAsInt().empty() ||
978+
!layout.getInstDataAsInt().empty())
973979
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
974980
layout.dropSgLayoutAndData());
975981
newShapeCastOps.push_back(newShapeCast.getResult());

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -370,31 +370,31 @@ gpu.module @test_distribution {
370370
// CHECK-LABEL: vector_step_op
371371
gpu.func @vector_step_op_slice_attr() {
372372
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
373-
//CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
374-
//CHECK: [[c32:%.+]] = arith.constant 32 : index
375-
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
376-
//CHECK: [[c0:%.+]] = arith.constant 0 : index
377-
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
378-
//CHECK: [[c128:%.+]] = arith.constant 128 : index
379-
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
380-
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
381-
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
373+
//CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
374+
//CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
375+
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
376+
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
377+
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
378+
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
379+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
380+
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
381+
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
382382
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
383383
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
384384
gpu.return
385385
}
386386

387387
gpu.func @vector_step_op_layout_attr() {
388388
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
389-
//CHECK: [[c16:%.+]] = arith.constant 16 : index
390-
//CHECK: [[c8:%.+]] = arith.constant 8 : index
391-
//CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
392-
//CHECK: [[c0:%.+]] = arith.constant 0 : index
393-
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
394-
//CHECK: [[c128:%.+]] = arith.constant 128 : index
395-
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
396-
//CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
397-
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
389+
//CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
390+
//CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
391+
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
392+
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
393+
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
394+
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
395+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
396+
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
397+
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
398398
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
399399
%step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
400400
gpu.return
@@ -414,8 +414,8 @@ gpu.module @test_distribution {
414414
%load = xegpu.load_nd %tdesc[0, 0]
415415
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
416416
-> vector<256x128xf32>
417-
//CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
418-
%cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
417+
//CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<2x16x4x8xf32>
418+
%cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [2, 16, 4, 8]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
419419
gpu.return
420420
}
421421
}

0 commit comments

Comments
 (0)