
Commit 88ff0f9

[MLIR][XeGPU] Distribute create_nd_desc op without offset from Wg to Sg (#152351)
This PR adds a pattern to distribute the create_nd_desc op without offsets from workgroup (Wg) IR to subgroup (Sg) IR. The round-robin distribution logic (which involves offset calculation) will now happen in the load/store/prefetch nd ops instead of in create_nd.
1 parent 31387d6 commit 88ff0f9
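
For illustration, a minimal before/after sketch of the new lowering, based on the tests added in this PR (types abbreviated). A workgroup-level descriptor without offsets becomes per-subgroup descriptors, still without offsets; the offsets are left to the consuming nd ops:

    // Workgroup IR: one descriptor carrying an sg_layout.
    %t = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>

    // Subgroup IR: four 16x16 descriptors per subgroup (round-robin), with
    // sg_layout/sg_data dropped. Offsets are not materialized here; the
    // consuming load/store/prefetch nd ops will compute them.
    %t0 = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
    // ... three more identical create_nd_tdesc ops ...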

File tree: 5 files changed, +113 −43 lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 5 deletions

@@ -142,11 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let builders = [
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-               "llvm::ArrayRef<OpFoldResult>": $shape,
-               "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value ": $source,
                "llvm::ArrayRef<OpFoldResult>": $shape,
                "llvm::ArrayRef<OpFoldResult>": $strides)>,

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 19 additions & 33 deletions

@@ -156,48 +156,37 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
+  Type srcTy = source.getType();
+  assert((isa<IntegerType, MemRefType>(srcTy)) &&
+         "Source has to be either int or memref.");
 
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
-        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
-        staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
-
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
 
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // if shape and strides are from Memref, we don't need attributes for them
+    // to keep the IR print clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
         dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
         staticStridesAttr);
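
Together with the printer tweak in the next hunk (empty integer arrays are no longer printed), dropping the redundant attributes keeps memref-sourced descriptors compact, while integer (pointer) sources still spell out shape and strides. Both printed forms, abbreviated from the tests below:

    %0 = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<...>
    %1 = xegpu.create_nd_tdesc %ptr, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<...>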
@@ -357,13 +346,10 @@ ParseResult parseOptionalDynamicIndexList(
 void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                    OperandRange values,
                                    DenseI64ArrayAttr integers) {
-
-  if (!integers)
+  if (!integers || integers.empty())
     return;
-
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, {},
-                               AsmParser::Delimiter::Square);
+  printDynamicIndexList(printer, op, values, integers,
+                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
 }
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 55 additions & 5 deletions

@@ -128,6 +128,12 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // Ensure that the op has explicit offsets specified (either dynamic or
+    // constant).
+    if (op.getMixedOffsets().empty())
+      return failure();
+
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
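
The existing pattern now bails out when the op carries no offsets, leaving those cases to the new WgToSgCreateNdOpNoOffset pattern below. For reference, a sketch of the two op forms (result types abbreviated):

    %a = xegpu.create_nd_tdesc %src[%x, %y] : memref<256x128xf32> -> ...  // offsets present: handled here
    %b = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> ...          // no offsets: handled by the new pattern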
@@ -199,6 +205,49 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Check no offsets are specified.
+    if (!op.getMixedOffsets().empty())
+      return failure();
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout || !layout.isWgLayout())
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+
+    SmallVector<Value> newCreateNdOps(count);
+    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+                                           op.getSource(), op.getMixedSizes(),
+                                           op.getMixedStrides());
+    });
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
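
The number of descriptors each subgroup receives comes from getSgShapeAndCount. A hedged sketch of the round-robin arithmetic this implies (matching the tests below, not necessarily the helper's exact implementation):

    // Each distribution round covers sg_layout[i] * sg_data[i] elements along
    // dimension i; the workgroup shape is tiled round-robin over those units.
    // wgShape = [256, 128], sg_layout = [8, 4], sg_data = [16, 16]:
    //   units = [8*16, 4*16] = [128, 64]
    //   count = (256/128) * (128/64) = 2 * 2 = 4 descriptors per subgroup
    // With sg_data = [32, 32], units = [256, 128] and count = 1.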
@@ -603,11 +652,12 @@ struct UnrealizedConversionCastOpPattern
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: create_nd_tdesc_with_ptr
+  // CHECK-SAME: %[[ARG_0:.*]]: ui64
+  gpu.func @create_nd_tdesc_with_ptr(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]], shape : [{{.*}}, {{.*}}], strides : [{{.*}}, {{.*}}] : ui64
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %c1 = arith.constant 1 : index
+    %tdesc = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides: [%w, %c1] : ui64
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
