Skip to content

Commit 77afdcc

Browse files
committed
address comment
1 parent ff1bb3b commit 77afdcc

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -847,15 +847,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
847847
Value mulOffset;
848848
for (auto offsets : *sgOffsets) {
849849
// Multiply offset with stride, broadcast it and add to baseConstVec
850-
SmallVector<Value> muls;
850+
Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
851851
for (size_t i = 0; i < strideConsts.size(); ++i) {
852-
muls.push_back(rewriter.create<arith::MulIOp>(
853-
loc, rewriter.getIndexType(), offsets[i], strideConsts[i]));
854-
}
855-
mulOffset = muls.front();
856-
if (muls.size() > 1) {
852+
Value mul = rewriter.create<arith::MulIOp>(
853+
loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
857854
mulOffset = rewriter.create<arith::AddIOp>(
858-
loc, rewriter.getIndexType(), mulOffset, muls[1]);
855+
loc, rewriter.getIndexType(), mulOffset, mul);
859856
}
860857
// Broadcast to baseConstVec size
861858
auto bcastOffset = rewriter.create<vector::BroadcastOp>(

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,16 @@ gpu.module @test_distribution {
111111
// CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
112112
// CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
113113
// CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
114+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
114115
// CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
115-
// CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[STRIDE1]], %[[STRIDE2]] : index
116+
// CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
116117
// CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
117118
// CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
118119
// CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
120+
// CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
119121
// CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
120-
// CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[STRIDE3]], %[[STRIDE4]] : index
121-
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES2]] : index to vector<2x1xindex>
122+
// CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
123+
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
122124
// CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
123125
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
124126
gpu.return

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,9 @@ gpu.module @test_distribution {
473473
// CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
474474
// CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
475475
// CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
476+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index
476477
// CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
477-
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[STRIDECOL]], %[[STRIDEROW]] : index
478+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[STRIDEROW]] : index
478479
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex>
479480
// CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex>
480481
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
@@ -494,9 +495,10 @@ gpu.module @test_distribution {
494495
// CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
495496
// CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
496497
// CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
498+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index
497499
// CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
498-
// CHECK-DAG: %[[ADDIDX:.*]] = arith.addi %[[MUL5]], %[[MUL6]] : index
499-
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDIDX]] : index to vector<2x2xindex>
500+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
501+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex>
500502
// CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
501503
%cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
502504
[0, 16, 32, 48, 64, 80, 96, 112],
@@ -517,7 +519,8 @@ gpu.module @test_distribution {
517519
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
518520
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
519521
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
520-
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
522+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index
523+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex>
521524
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
522525
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
523526
// CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>

0 commit comments

Comments
 (0)