
Commit 4a9d038

[MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (llvm#153432)
This PR adds patterns to distribute load_nd, store_nd, and prefetch_nd ops that carry explicit offsets from workgroup IR to subgroup IR. It is part of the transition that moves offsets from create_nd_desc to the load/store/prefetch nd ops. Create_nd PR: llvm#152351
1 parent d6e0922 commit 4a9d038
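For intuition, each new pattern adds a subgroup's layout-relative offset to the op's original workgroup offsets. The standalone C++ sketch below mirrors that arithmetic for a hypothetical sg_layout = [2, 2], sg_data = [16, 16] and workgroup offsets (64, 32); it illustrates the math only, is not the pass's API, and ignores round-robin wrapping for workgroup shapes larger than sg_layout * sg_data:

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical layout: 2x2 subgroups, each owning a 16x16 tile.
  const std::array<int64_t, 2> sgLayout = {2, 2};
  const std::array<int64_t, 2> sgData = {16, 16};
  // Offsets carried by the workgroup-level op.
  const std::array<int64_t, 2> wgOffsets = {64, 32};

  for (int64_t sgId = 0; sgId < sgLayout[0] * sgLayout[1]; ++sgId) {
    // Delinearize the linear subgroup id (row-major).
    const int64_t row = sgId / sgLayout[1];
    const int64_t col = sgId % sgLayout[1];
    // Final offset = original workgroup offset + subgroup-relative offset.
    const int64_t off0 = wgOffsets[0] + row * sgData[0];
    const int64_t off1 = wgOffsets[1] + col * sgData[1];
    std::printf("sg %lld accesses at (%lld, %lld)\n",
                static_cast<long long>(sgId), static_cast<long long>(off0),
                static_cast<long long>(off1));
  }
  return 0;
}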

File tree

5 files changed

+586 -11 lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 17 additions & 1 deletion
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
 
   let builders = [
     OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
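The new overloads let callers pass mixed static/dynamic offsets as OpFoldResults instead of folding them into the tensor descriptor. A hedged usage sketch follows; `rewriter`, `loc`, `tdesc`, `resTy`, and the dynamic index value `dynOff` are assumed to be in scope, and all cache hints are left unset:

// Sketch only: exercising the three new offset-taking builders.
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), // static offset
                                     OpFoldResult(dynOff)};    // dynamic Value

auto load = rewriter.create<xegpu::LoadNdOp>(
    loc, resTy, tdesc, offsets, /*packed=*/nullptr, /*transpose=*/nullptr,
    /*l1_hint=*/nullptr, /*l2_hint=*/nullptr, /*l3_hint=*/nullptr);

rewriter.create<xegpu::StoreNdOp>(loc, load.getResult(), tdesc, offsets,
                                  /*l1_hint=*/nullptr, /*l2_hint=*/nullptr,
                                  /*l3_hint=*/nullptr);

rewriter.create<xegpu::PrefetchNdOp>(loc, tdesc, offsets, /*l1_hint=*/nullptr,
                                     /*l2_hint=*/nullptr, /*l3_hint=*/nullptr);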

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 46 additions & 0 deletions
@@ -385,6 +385,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
         l1_hint, l2_hint, l3_hint);
 }
 
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+        l2_hint, l3_hint);
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
@@ -427,6 +442,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
         l3_hint);
 }
 
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                     UnitAttr packed, DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -533,6 +564,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
         DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
 }
 
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                      xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
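Each overload funnels the OpFoldResult list through dispatchIndexOpFoldResults (from mlir/Dialect/Utils/StaticValueUtils.h), which splits it into a static integer array and a list of runtime values. A short sketch of that split, with %d standing for some dynamic index Value:

// Given offsets = {IndexAttr(8), Value %d, IndexAttr(0)}:
SmallVector<Value> dynamicOffsets;  // becomes {%d}
SmallVector<int64_t> staticOffsets; // becomes {8, ShapedType::kDynamic, 0}
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
// The kDynamic sentinel marks where a runtime Value fills in, so the
// DenseI64ArrayAttr plus the dynamic operand list round-trips the original
// mixed offsets.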

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 209 additions & 9 deletions
@@ -182,16 +182,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
             layout.dropSgLayoutAndData());
 
     SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
 
     for (auto tdescOffsets : *maybeTdescOffsets) {
       SmallVector<OpFoldResult> sgOffsets;
       size_t rank = tdescOffsets.size();
       for (size_t i = 0; i < rank; i++) {
-        size_t idx = wgOffsets.size() - rank + i;
+        size_t idx = origOffsets.size() - rank + i;
         Value add = rewriter.createOrFold<index::AddOp>(
             loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
         sgOffsets.push_back(add);
       }
 
@@ -296,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+// Utility function to compute global offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and subgroup relative offsets.
+static SmallVector<SmallVector<OpFoldResult>>
+computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+               ArrayRef<OpFoldResult> origOffsets,
+               ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> finalOffsets;
+  Location loc = op->getLoc();
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = sgOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = origOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    finalOffsets.push_back(std::move(newOffsets));
+  }
+  return finalOffsets;
+}
+
+// Utility function to get sgShape, sgOffsetList for a given
+// op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
+                           ConversionPatternRewriter &rewriter,
+                           SmallVector<int64_t> &sgShape,
+                           SmallVector<SmallVector<Value>> &sgOffsetList) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+
+  SmallVector<int64_t> sgLayout;
+  auto sgLayoutAttr = layout.getSgLayout();
+  if (!sgLayoutAttr)
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  int count;
+  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+  // Get the subgroup ID
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+  if (sgIdRangeSpecified) {
+    int64_t sgCount = endOfRange - startOfRange;
+    if (computeProduct(sgLayout) != sgCount)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
+    return failure();
+
+  sgOffsetList = *sgOffsets;
+  return success();
+}
+
+template <typename OpTy>
+SmallVector<OpFoldResult> getOffsets(OpTy op,
+                                     ConversionPatternRewriter &rewriter) {
+  SmallVector<OpFoldResult> origOffsets;
+  if (auto constOffsets = op.getConstOffsetsAttr()) {
+    for (auto attr : constOffsets.asArrayRef())
+      origOffsets.push_back(rewriter.getIndexAttr(attr));
+  }
+  for (auto v : op.getOffsets())
+    origOffsets.push_back(v);
+  return origOffsets;
+}
+
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    SmallVector<Value> newLoadOps;
+    for (auto [offsets, tdesc] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+      VectorType newResTy = VectorType::get(
+          sgShape,
+          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, offsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc, value] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+    : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
 /// offsets of the new subgroup src tensor descriptors.
@@ -690,12 +889,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-               WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-               WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
-               WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
-               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+           WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
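Like the existing patterns, the new ones fire under the standard dialect-conversion driver. A minimal hedged sketch of wiring them up; the function name is illustrative and the legality setup is elided (the real xegpu-wg-to-sg-distribute pass marks ops whose types still carry an sg_layout as illegal so these patterns apply):

void runWgToSgDistribution(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);

  ConversionTarget target(*ctx);
  // Legality setup elided: ops with sg_layout-carrying tensor_desc types
  // must be marked illegal for the patterns above to be driven.

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("workgroup-to-subgroup distribution failed");
}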
