Skip to content

Commit eb7db0b

Browse files
authored
[mlir][xegpu] Change index arithmetic ops to arith ops. (#170390)
Index ops cause some issues during SIMT distribution because they don't have the `Elementwise` mappable trait. This PR replaces all index arithmetic ops with matching `arith` dialect ops.
1 parent 267865a commit eb7db0b

File tree

8 files changed

+123
-134
lines changed

8 files changed

+123
-134
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include "mlir/Dialect/Affine/Utils.h"
1010
#include "mlir/Dialect/Arith/Utils/Utils.h"
11-
#include "mlir/Dialect/Index/IR/IndexOps.h"
1211
#include "mlir/Dialect/Utils/IndexingUtils.h"
1312
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1413
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
@@ -61,7 +60,7 @@ genCoordinates(OpBuilder &builder, Location loc,
6160
// Get the offset of `subShape` within a distribution unit.
6261
SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
6362
llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
64-
return builder.createOrFold<index::MulOp>(
63+
return builder.createOrFold<arith::MulIOp>(
6564
loc, std::get<0>(t),
6665
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
6766
});
@@ -84,7 +83,7 @@ genCoordinates(OpBuilder &builder, Location loc,
8483
// Do not go beyond `srcShape` bounds.
8584
SmallVector<Value> mods = llvm::map_to_vector(
8685
llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
87-
return builder.createOrFold<index::RemUOp>(
86+
return builder.createOrFold<arith::RemUIOp>(
8887
loc, std::get<0>(t),
8988
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
9089
});
@@ -343,7 +342,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
343342
/// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
344343
/// this dimension)
345344
result[dimIdx] =
346-
builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
345+
builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
347346

348347
/// Update remaining for the next dimension by removing what we've already
349348
/// processed. Division tells us "how many complete groups of this dimension
@@ -352,7 +351,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
352351
/// no next dimension to process
353352
if (i < order.size() - 1) {
354353
remaining =
355-
builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
354+
builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
356355
}
357356
}
358357
return result;

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1414
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15-
#include "mlir/Dialect/Index/IR/IndexOps.h"
1615
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
1716
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
1817
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -527,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
527526
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
528527
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
529528
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
530-
results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
529+
results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
531530
}
532531
return results;
533532
}

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,11 @@ gpu.module @xevm_module{
271271
// CHECK: %[[C2:.*]] = arith.constant 2 : index
272272
// CHECK: %[[C8:.*]] = arith.constant 8 : index
273273
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
274-
// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C8]]
275-
// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C8]]
276-
// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C2]]
277-
// CHECK: %[[REMU3:.*]] = index.remu %[[REMU2]], %[[C2]]
278-
// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C8]]
274+
// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C8]]
275+
// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C8]]
276+
// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C2]]
277+
// CHECK: %[[REMU3:.*]] = arith.remui %[[REMU2]], %[[C2]]
278+
// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C8]]
279279
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
280280
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
281281
gpu.module @xevm_module{
@@ -294,13 +294,13 @@ gpu.module @xevm_module{
294294
// CHECK: %[[C4:.*]] = arith.constant 4 : index
295295
// CHECK: %[[C1:.*]] = arith.constant 1 : index
296296
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
297-
// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C4]]
298-
// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C4]]
299-
// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C4]]
300-
// CHECK: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2]]
301-
// CHECK: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C8]]
302-
// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C4]]
303-
// CHECK: %[[ADD:.*]] = index.add %[[REMU4]], %[[C1]]
297+
// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C4]]
298+
// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C4]]
299+
// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C4]]
300+
// CHECK: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C2]]
301+
// CHECK: %[[REMU3:.*]] = arith.remui %[[MUL]], %[[C8]]
302+
// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C4]]
303+
// CHECK: %[[ADD:.*]] = arith.addi %[[REMU4]], %[[C1]]
304304
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
305305
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
306306
gpu.module @xevm_module{

mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
gpu.module @test {
44
gpu.func @slice_attr() -> vector<128xindex> {
55
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
6-
// CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
7-
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
8-
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
9-
// CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
6+
// CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C8:.*]]
7+
// CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU]], %[[C4:.*]]
8+
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]]
9+
// CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]]
1010
// CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
1111
// CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
1212
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
@@ -16,11 +16,10 @@ gpu.module @test {
1616

1717
gpu.func @nested_slice_attr() -> vector<128xindex> {
1818
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
19-
// CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
20-
// CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
21-
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
22-
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
23-
// CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
19+
// CHECK-DAG: %[[DIVU2:.*]] = arith.divui %[[SGID]], %[[C8:.*]]
20+
// CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU2]], %[[C4:.*]]
21+
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]]
22+
// CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]]
2423
// CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
2524
// CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
2625
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
@@ -29,4 +28,3 @@ gpu.module @test {
2928
}
3029

3130
}
32-

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ gpu.module @test_round_robin_assignment {
1616
gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
1717
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
1818
// CHECK: %[[C4:.*]] = arith.constant 4 : index
19-
// CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
20-
// CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
19+
// CHECK: %[[IDX:.*]] = arith.remui %[[SGID]], %[[C4]]
20+
// CHECK: %[[IDY_DIV:.*]] = arith.divui %[[SGID]], %[[C4]]
2121
// CHECK: %[[C8:.*]] = arith.constant 8 : index
22-
// CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
22+
// CHECK: %[[IDY:.*]] = arith.remui %[[IDY_DIV]], %[[C8]]
2323
// CHECK: %[[C16:.*]] = arith.constant 16 : index
24-
// CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]]
24+
// CHECK: %[[LY:.*]] = arith.muli %[[IDY]], %[[C16]]
2525
// CHECK: %[[C64:.*]] = arith.constant 64 : index
26-
// CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]]
26+
// CHECK: %[[LX:.*]] = arith.muli %[[IDX]], %[[C64]]
2727
// CHECK: %[[C128:.*]] = arith.constant 128 : index
28-
// CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]]
28+
// CHECK: %[[OFFY:.*]] = arith.remui %[[LY]], %[[C128]]
2929
// CHECK: %[[C64_1:.*]] = arith.constant 64 : index
30-
// CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]]
30+
// CHECK: %[[OFFX:.*]] = arith.remui %[[LX]], %[[C64_1]]
3131
// CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
3232
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
3333
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -90,30 +90,27 @@ gpu.module @test_distribution {
9090
gpu.return
9191
}
9292

93+
// CHECK-LABEL: non_splat_constant
9394
gpu.func @non_splat_constant() {
94-
// CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
95+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{.*}}0{{.*}}, {{.*}}16{{.*}}> : vector<2x1xindex>
9596
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
96-
// CHECK-DAG: %[[REMU1:.*]] = index.remu %[[SGID]], %[[C1:.*]]
97-
// CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C1:.*]]
98-
// CHECK-DAG: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
99-
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2:.*]]
100-
// CHECK-DAG: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C32:.*]]
101-
// CHECK-DAG: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
102-
// CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
103-
// CHECK-DAG: %[[REMU5:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
104-
// CHECK-DAG: %[[REMU6:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
105-
// CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
106-
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
107-
// CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
108-
// CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
109-
// CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
110-
// CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
111-
// CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU5]], %[[C16:.*]] : index
112-
// CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
113-
// CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU6]], %[[C0:.*]] : index
114-
// CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
115-
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
116-
// CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
97+
// CHECK-DAG: %[[T1:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
98+
// CHECK-DAG: %[[T2:.*]] = arith.muli %[[T1]], %[[C2:.*]] : index
99+
// CHECK-DAG: %[[T3:.*]] = arith.remui %[[T2]], %[[C32:.*]] : index
100+
// CHECK-DAG: %[[T4:.*]] = arith.addi %[[T2]], %[[C16:.*]] : index
101+
// CHECK-DAG: %[[T5:.*]] = arith.remui %[[T4]], %[[C32_6:.*]] : index
102+
// CHECK-DAG: %[[T6:.*]] = arith.muli %[[T3]], %[[C16_10:.*]] : index
103+
// CHECK-DAG: %[[T7:.*]] = arith.addi %[[C0_11:.*]], %[[T6]] : index
104+
// CHECK-DAG: %[[T8:.*]] = arith.muli %[[C0_4:.*]], %[[C0_9:.*]] : index
105+
// CHECK-DAG: %[[T9:.*]] = arith.addi %[[T7]], %[[T8]] : index
106+
// CHECK-DAG: %[[T10:.*]] = vector.broadcast %[[T9]] : index to vector<2x1xindex>
107+
// CHECK-DAG: %[[T11:.*]] = arith.addi %[[CST]], %[[T10]] : vector<2x1xindex>
108+
// CHECK-DAG: %[[T12:.*]] = arith.muli %[[T5]], %[[C16_10:.*]] : index
109+
// CHECK-DAG: %[[T13:.*]] = arith.addi %[[C0_12:.*]], %[[T12]] : index
110+
// CHECK-DAG: %[[T14:.*]] = arith.muli %[[C0_8:.*]], %[[C0_9:.*]] : index
111+
// CHECK-DAG: %[[T15:.*]] = arith.addi %[[T13]], %[[T14]] : index
112+
// CHECK-DAG: %[[T16:.*]] = vector.broadcast %[[T15]] : index to vector<2x1xindex>
113+
// CHECK-DAG: %[[T17:.*]] = arith.addi %[[CST]], %[[T16]] : vector<2x1xindex>
117114
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
118115
gpu.return
119116
}
@@ -139,4 +136,3 @@ gpu.module @test_distribution {
139136
gpu.return
140137
}
141138
}
142-

0 commit comments

Comments
 (0)