Skip to content

Commit 1381174

Browse files
committed
Fix CHECKS
1 parent 512478b commit 1381174

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
831831
// Multiply offset with stride, broadcast it and add to baseConstVec
832832
Value mulOffset = rewriter.create<arith::MulIOp>(
833833
loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
834-
auto bcastOffset = rewriter.create<vector::SplatOp>(
834+
auto bcastOffset = rewriter.create<vector::BroadcastOp>(
835835
loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
836836
auto finalConst =
837837
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,26 +102,34 @@ gpu.module @test_distribution {
102102
gpu.func @non_splat_constant() {
103103
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
104104
// CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
105+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
106+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
107+
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
105108
// CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
106109
// CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
107110
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
108111
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
112+
// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
109113
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
110-
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
111-
// CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
112-
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
114+
// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
113115
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
114-
// CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
115-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
116+
// CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]]
117+
// CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
118+
// CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]]
116119
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
117-
// CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
118-
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
119-
// CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
120-
// CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
121-
// CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
122-
// CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
123-
// CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
124-
// CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
120+
// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
121+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index
122+
// CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index
123+
// CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]]
124+
// CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
125+
// CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]]
126+
// CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index
127+
// CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index
128+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex>
129+
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex>
130+
// CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index
131+
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex>
132+
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex>
125133
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
126134
gpu.return
127135
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,18 +463,17 @@ gpu.module @test_distribution {
463463
gpu.func @non_splat_constant() {
464464
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
465465
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
466+
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
467+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
466468
// CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
467469
// CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
468470
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
469-
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
470-
// CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
471-
// CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
472-
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
473-
// CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
474-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
471+
// CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]]
472+
// CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]]
475473
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
476474
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
477-
// CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
475+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
476+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
478477
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
479478
gpu.return
480479
}

0 commit comments

Comments
 (0)