Skip to content

Commit 1d5f75a

Browse files
committed
Use getValueOrCreateConstantIndexOp
1 parent a092bd3 commit 1d5f75a

File tree

2 files changed

+9
-16
lines changed

2 files changed

+9
-16
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
99

1010
#include "mlir/Dialect/Affine/Utils.h"
11+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1112
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1213
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1314
#include "mlir/Dialect/Index/IR/IndexOps.h"
@@ -86,17 +87,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
8687
assert(localOffset.size() == distUnitBaseAddr.size() &&
8788
"localOffset and distUnitBaseAddr must have the same rank");
8889

89-
// Convert originalOffsets to Value
90-
auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
91-
if (auto val = ofr.dyn_cast<Value>())
92-
return val;
93-
if (auto attr = ofr.dyn_cast<Attribute>()) {
94-
int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
95-
return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
96-
}
97-
llvm_unreachable("Unsupported OpFoldResult kind");
98-
};
99-
10090
SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
10191
originalOffsets.end());
10292
size_t rank = localOffset.size();
@@ -110,7 +100,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
110100
rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
111101
Value offsetMod =
112102
rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
113-
Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
103+
Value origOffset = getValueOrCreateConstantIndexOp(
104+
rewriter, loc, originalOffsets[dimIdx]);
114105
Value globalOffset =
115106
rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
116107
globalOffsets[dimIdx] = globalOffset;
@@ -135,8 +126,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
135126
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
136127
} else {
137128
// sgLayout must be present for workgroup-level distribution.
138-
op.emitError("sgLayout attribute is required in layout");
139-
return failure();
129+
return rewriter.notifyMatchFailure(
130+
op, "sgLayout attribute is required in layout");
140131
}
141132

142133
SmallVector<int64_t> sgShape;

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ gpu.module @test_round_robin_assignment {
1515
// CHECK-LABEL: test_load_nd_tdesc
1616
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
1717
gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
18-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
18+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
19+
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
1920
// CHECK-COUNT-12: xegpu.load_nd %{{.*}}
2021
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
2122
// CHECK-SAME-COUNT-12: -> vector<2x2xf32>
@@ -45,7 +46,8 @@ gpu.module @test_round_robin_assignment {
4546
// CHECK-LABEL: test_update_nd
4647
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
4748
gpu.func @test_update_nd(%src: memref<24x32xf32>){
48-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
49+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
50+
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
4951
// CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
5052
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
5153
// CHECK-NOT: xegpu.update_nd_offset

0 commit comments

Comments
 (0)