Skip to content

Commit fcbdb91

Browse files
committed
Add pattern for load/store/prefetch nd with offsets
1 parent 9e799d6 commit fcbdb91

File tree

3 files changed

+241
-1
lines changed

3 files changed

+241
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
272272

273273
let builders = [
274274
OpBuilder<(ins "Value": $TensorDesc,
275+
"xegpu::CachePolicyAttr": $l1_hint,
276+
"xegpu::CachePolicyAttr": $l2_hint,
277+
"xegpu::CachePolicyAttr": $l3_hint)>,
278+
OpBuilder<(ins "Value": $TensorDesc,
279+
"ArrayRef<OpFoldResult>": $offsets,
275280
"xegpu::CachePolicyAttr": $l1_hint,
276281
"xegpu::CachePolicyAttr": $l2_hint,
277282
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
348353

349354
let builders = [
350355
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
356+
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
357+
"xegpu::CachePolicyAttr": $l1_hint,
358+
"xegpu::CachePolicyAttr": $l2_hint,
359+
"xegpu::CachePolicyAttr": $l3_hint)>,
360+
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
361+
"ArrayRef<OpFoldResult>": $offsets,
351362
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
352363
"xegpu::CachePolicyAttr": $l1_hint,
353364
"xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
419430
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
420431
"xegpu::CachePolicyAttr": $l1_hint,
421432
"xegpu::CachePolicyAttr": $l2_hint,
422-
"xegpu::CachePolicyAttr": $l3_hint)>
433+
"xegpu::CachePolicyAttr": $l3_hint)>,
434+
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
435+
"ArrayRef<OpFoldResult>": $offsets,
436+
"xegpu::CachePolicyAttr": $l1_hint,
437+
"xegpu::CachePolicyAttr": $l2_hint,
438+
"xegpu::CachePolicyAttr": $l3_hint)>
423439
];
424440

425441

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
364364
l1_hint, l2_hint, l3_hint);
365365
}
366366

367+
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
368+
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
369+
xegpu::CachePolicyAttr l1_hint,
370+
xegpu::CachePolicyAttr l2_hint,
371+
xegpu::CachePolicyAttr l3_hint) {
372+
SmallVector<Value> dynamicOffsets;
373+
SmallVector<int64_t> staticOffsets;
374+
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
375+
376+
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
377+
378+
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
379+
l2_hint, l3_hint);
380+
}
381+
367382
LogicalResult PrefetchNdOp::verify() {
368383
auto tdescTy = getTensorDescType();
369384
if (tdescTy.isScattered())
@@ -406,6 +421,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
406421
l3_hint);
407422
}
408423

424+
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
425+
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
426+
UnitAttr packed, DenseI64ArrayAttr transpose,
427+
xegpu::CachePolicyAttr l1_hint,
428+
xegpu::CachePolicyAttr l2_hint,
429+
xegpu::CachePolicyAttr l3_hint) {
430+
SmallVector<Value> dynamicOffsets;
431+
SmallVector<int64_t> staticOffsets;
432+
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
433+
434+
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
435+
436+
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
437+
packed, transpose, l1_hint, l2_hint, l3_hint);
438+
}
439+
409440
LogicalResult LoadNdOp::verify() {
410441
auto tdescTy = getTensorDescType();
411442
auto valueTy = getType();
@@ -512,6 +543,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
512543
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
513544
}
514545

546+
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
547+
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
548+
xegpu::CachePolicyAttr l1_hint,
549+
xegpu::CachePolicyAttr l2_hint,
550+
xegpu::CachePolicyAttr l3_hint) {
551+
SmallVector<Value> dynamicOffsets;
552+
SmallVector<int64_t> staticOffsets;
553+
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
554+
555+
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
556+
557+
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
558+
l1_hint, l2_hint, l3_hint);
559+
}
560+
515561
LogicalResult StoreNdOp::verify() {
516562
auto dstTy = getTensorDescType(); // Tile
517563
auto valTy = getValueType(); // Vector

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,82 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
7777
return std::make_pair(sgShape, count);
7878
}
7979

80+
// Helper function to compute new offsets for subgroup operations.
81+
static SmallVector<SmallVector<OpFoldResult>>
82+
computeSgOffsets(PatternRewriter &rewriter, Location loc,
83+
xegpu::LayoutAttr layout, Value linearSgId,
84+
ArrayRef<int64_t> wgShape, ArrayRef<OpFoldResult> oldOffsets) {
85+
SmallVector<SmallVector<OpFoldResult>> result;
86+
auto maybeTdescOffsets =
87+
layout.getOffsets(rewriter, loc, linearSgId, wgShape);
88+
if (failed(maybeTdescOffsets))
89+
return result;
90+
91+
for (auto &tdescOffsets : *maybeTdescOffsets) {
92+
SmallVector<OpFoldResult> newOffsets;
93+
size_t rank = tdescOffsets.size();
94+
for (size_t i = 0; i < rank; i++) {
95+
size_t idx = oldOffsets.size() - rank + i;
96+
Value add = rewriter.createOrFold<index::AddOp>(
97+
loc, tdescOffsets[i],
98+
getValueOrCreateConstantIndexOp(rewriter, loc, oldOffsets[idx]));
99+
newOffsets.push_back(add);
100+
}
101+
result.push_back(std::move(newOffsets));
102+
}
103+
return result;
104+
}
105+
106+
// Helper struct to hold extracted subgroup info for ops with explicit offsets.
107+
struct SgOffsetInfo {
108+
Location loc;
109+
Value tdesc;
110+
xegpu::TensorDescType tdescTy;
111+
xegpu::LayoutAttr layout;
112+
SmallVector<int64_t> sgShape;
113+
int count;
114+
Value linearSgId;
115+
SmallVector<OpFoldResult> oldOffsets;
116+
};
117+
118+
// Helper function to extract subgroup info for ops with explicit offsets.
119+
// Returns std::nullopt on failure.
120+
template <typename OpTy>
121+
std::optional<SgOffsetInfo>
122+
extractSgOffsetInfo(OpTy op, ConversionPatternRewriter &rewriter) {
123+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
124+
if (offsetSize == 0)
125+
return std::nullopt;
126+
127+
Location loc = op.getLoc();
128+
Value tdesc = op.getTensorDesc();
129+
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
130+
if (!tdescTy)
131+
return std::nullopt;
132+
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
133+
if (!layout)
134+
return std::nullopt;
135+
136+
ArrayRef<int64_t> wgShape = tdescTy.getShape();
137+
SmallVector<int64_t> sgShape;
138+
int count;
139+
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
140+
141+
Value linearSgId =
142+
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
143+
144+
SmallVector<OpFoldResult> oldOffsets;
145+
if (auto constOffsets = op.getConstOffsetsAttr()) {
146+
for (auto attr : constOffsets.asArrayRef())
147+
oldOffsets.push_back(rewriter.getIndexAttr(attr));
148+
}
149+
for (auto v : op.getOffsets())
150+
oldOffsets.push_back(v);
151+
152+
return SgOffsetInfo{loc, tdesc, tdescTy, layout,
153+
sgShape, count, linearSgId, oldOffsets};
154+
}
155+
80156
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
81157
/// from a workgroup descriptor. It replaces the offsets and sizes with
82158
/// appropriate values for the subgroup.
@@ -275,6 +351,43 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
275351
}
276352
};
277353

354+
// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
355+
// data.
356+
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
357+
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
358+
359+
LogicalResult
360+
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
361+
ConversionPatternRewriter &rewriter) const override {
362+
363+
auto infoOpt = extractSgOffsetInfo(op, rewriter);
364+
if (!infoOpt)
365+
return failure();
366+
const auto &info = *infoOpt;
367+
368+
auto sgOffsets =
369+
computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
370+
info.tdescTy.getShape(), info.oldOffsets);
371+
if (sgOffsets.empty())
372+
return failure();
373+
374+
SmallVector<Value> newLoadOps;
375+
auto tdescRange = adaptor.getTensorDesc();
376+
for (auto it : llvm::zip(sgOffsets, tdescRange)) {
377+
VectorType newResTy =
378+
VectorType::get(info.sgShape, info.tdescTy.getElementType());
379+
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
380+
info.loc, newResTy, std::get<1>(it), std::get<0>(it),
381+
/*packed=*/nullptr,
382+
/*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
383+
op.getL3HintAttr());
384+
newLoadOps.push_back(newLoadOp);
385+
}
386+
rewriter.replaceOpWithMultiple(op, {newLoadOps});
387+
return success();
388+
}
389+
};
390+
278391
/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
279392
/// It creates a StoreNdOp op to store the updated values to the new subgroup
280393
/// src tensor descriptors.
@@ -297,6 +410,39 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
297410
}
298411
};
299412

413+
// This pattern transforms the StoreNdOp with explicit offsets to store
414+
// subgroup data.
415+
struct WgToSgStoreNdOpWithOffset
416+
: public OpConversionPattern<xegpu::StoreNdOp> {
417+
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
418+
419+
LogicalResult
420+
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
421+
ConversionPatternRewriter &rewriter) const override {
422+
423+
auto infoOpt = extractSgOffsetInfo(op, rewriter);
424+
if (!infoOpt)
425+
return failure();
426+
const auto &info = *infoOpt;
427+
428+
auto sgOffsets =
429+
computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
430+
info.tdescTy.getShape(), info.oldOffsets);
431+
if (sgOffsets.empty())
432+
return failure();
433+
434+
auto tdescRange = adaptor.getTensorDesc();
435+
auto valueRange = adaptor.getValue();
436+
for (auto it : llvm::zip(sgOffsets, tdescRange, valueRange)) {
437+
rewriter.create<xegpu::StoreNdOp>(
438+
info.loc, std::get<2>(it), std::get<1>(it), std::get<0>(it),
439+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
440+
}
441+
rewriter.eraseOp(op);
442+
return success();
443+
}
444+
};
445+
300446
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
301447
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
302448
/// offsets of the new subgroup src tensor descriptors.
@@ -383,6 +529,38 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
383529
}
384530
};
385531

532+
// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
533+
// subgroup data.
534+
struct WgToSgPrefetchNdOpWithOffset
535+
: public OpConversionPattern<xegpu::PrefetchNdOp> {
536+
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
537+
538+
LogicalResult
539+
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
540+
ConversionPatternRewriter &rewriter) const override {
541+
542+
auto infoOpt = extractSgOffsetInfo(op, rewriter);
543+
if (!infoOpt)
544+
return failure();
545+
const auto &info = *infoOpt;
546+
547+
auto sgOffsets =
548+
computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
549+
info.tdescTy.getShape(), info.oldOffsets);
550+
if (sgOffsets.empty())
551+
return failure();
552+
553+
auto tdescRange = adaptor.getTensorDesc();
554+
for (auto it : llvm::zip(sgOffsets, tdescRange)) {
555+
rewriter.create<xegpu::PrefetchNdOp>(
556+
info.loc, std::get<1>(it), std::get<0>(it), op.getL1HintAttr(),
557+
op.getL2HintAttr(), op.getL3HintAttr());
558+
}
559+
rewriter.eraseOp(op);
560+
return success();
561+
}
562+
};
563+
386564
/// This pattern transforms vector.broadcast ops to work at subgroup level.
387565
struct WgToSgVectorBroadcastOp
388566
: public OpConversionPattern<vector::BroadcastOp> {

0 commit comments

Comments
 (0)