Skip to content

Commit 35bdf57

Browse files
committed
Refactor
1 parent 639e997 commit 35bdf57

File tree

2 files changed

+207
-196
lines changed

2 files changed

+207
-196
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 180 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -77,83 +77,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
7777
return std::make_pair(sgShape, count);
7878
}
7979

80-
// Helper function to compute new offsets for subgroup operations.
81-
static SmallVector<SmallVector<OpFoldResult>>
82-
computeSgOffsets(PatternRewriter &rewriter, Location loc,
83-
xegpu::LayoutAttr layout, Value linearSgId,
84-
ArrayRef<int64_t> wgShape, ArrayRef<OpFoldResult> oldOffsets) {
85-
SmallVector<SmallVector<OpFoldResult>> result;
86-
auto maybeTdescOffsets =
87-
layout.getOffsets(rewriter, loc, linearSgId, wgShape);
88-
if (failed(maybeTdescOffsets))
89-
return result;
90-
91-
for (auto &tdescOffsets : *maybeTdescOffsets) {
92-
SmallVector<OpFoldResult> newOffsets;
93-
size_t rank = tdescOffsets.size();
94-
for (size_t i = 0; i < rank; i++) {
95-
size_t idx = oldOffsets.size() - rank + i;
96-
Value add = rewriter.createOrFold<index::AddOp>(
97-
loc, tdescOffsets[i],
98-
getValueOrCreateConstantIndexOp(rewriter, loc, oldOffsets[idx]));
99-
newOffsets.push_back(add);
100-
}
101-
result.push_back(std::move(newOffsets));
102-
}
103-
return result;
104-
}
105-
106-
// Helper struct to hold extracted subgroup info for ops with explicit offsets.
107-
struct SgOffsetInfo {
108-
Location loc;
109-
Value tdesc;
110-
xegpu::TensorDescType tdescTy;
111-
xegpu::LayoutAttr layout;
112-
SmallVector<int64_t> sgShape;
113-
int count;
114-
Value linearSgId;
115-
SmallVector<OpFoldResult> oldOffsets;
116-
};
117-
118-
// Helper function to extract subgroup info for ops with explicit offsets.
119-
// Returns std::nullopt on failure.
120-
template <typename OpTy>
121-
std::optional<SgOffsetInfo>
122-
extractSgOffsetInfo(OpTy op, ConversionPatternRewriter &rewriter) {
123-
124-
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
125-
if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
126-
return std::nullopt;
127-
128-
Location loc = op.getLoc();
129-
Value tdesc = op.getTensorDesc();
130-
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
131-
if (!tdescTy)
132-
return std::nullopt;
133-
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
134-
if (!layout)
135-
return std::nullopt;
136-
137-
ArrayRef<int64_t> wgShape = tdescTy.getShape();
138-
SmallVector<int64_t> sgShape;
139-
int count;
140-
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
141-
142-
Value linearSgId =
143-
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
144-
145-
SmallVector<OpFoldResult> oldOffsets;
146-
if (auto constOffsets = op.getConstOffsetsAttr()) {
147-
for (auto attr : constOffsets.asArrayRef())
148-
oldOffsets.push_back(rewriter.getIndexAttr(attr));
149-
}
150-
for (auto v : op.getOffsets())
151-
oldOffsets.push_back(v);
152-
153-
return SgOffsetInfo{loc, tdesc, tdescTy, layout,
154-
sgShape, count, linearSgId, oldOffsets};
155-
}
156-
15780
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
15881
/// from a workgroup descriptor. It replaces the offsets and sizes with
15982
/// appropriate values for the subgroup.
@@ -351,43 +274,6 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
351274
}
352275
};
353276

354-
// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
355-
// data.
356-
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
357-
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
358-
359-
LogicalResult
360-
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
361-
ConversionPatternRewriter &rewriter) const override {
362-
363-
auto infoOpt = extractSgOffsetInfo(op, rewriter);
364-
if (!infoOpt)
365-
return failure();
366-
const auto &info = *infoOpt;
367-
368-
auto sgOffsets =
369-
computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
370-
info.tdescTy.getShape(), info.oldOffsets);
371-
if (sgOffsets.empty())
372-
return failure();
373-
374-
SmallVector<Value> newLoadOps;
375-
auto tdescRange = adaptor.getTensorDesc();
376-
for (auto it : llvm::zip(sgOffsets, tdescRange)) {
377-
VectorType newResTy =
378-
VectorType::get(info.sgShape, info.tdescTy.getElementType());
379-
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
380-
info.loc, newResTy, std::get<1>(it), std::get<0>(it),
381-
/*packed=*/nullptr,
382-
/*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
383-
op.getL3HintAttr());
384-
newLoadOps.push_back(newLoadOp);
385-
}
386-
rewriter.replaceOpWithMultiple(op, {newLoadOps});
387-
return success();
388-
}
389-
};
390-
391277
/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
392278
/// It creates a StoreNdOp op to store the updated values to the new subgroup
393279
/// src tensor descriptors.
@@ -410,36 +296,192 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
410296
}
411297
};
412298

299+
// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
300+
// data.
301+
// Use a template parameter for the adaptor type
302+
template <typename OpTy, typename AdaptorTy, typename CreateFn>
303+
LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
304+
ConversionPatternRewriter &rewriter,
305+
CreateFn &&createOp) {
306+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
307+
if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
308+
return failure();
309+
310+
Location loc = op.getLoc();
311+
Value tdesc = op.getTensorDesc();
312+
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
313+
if (!tdescTy)
314+
return failure();
315+
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
316+
if (!layout)
317+
return failure();
318+
319+
SmallVector<int64_t> sgLayout;
320+
if (auto sgLayoutAttr = layout.getSgLayout())
321+
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
322+
else
323+
return rewriter.notifyMatchFailure(
324+
op, "sgLayout attribute is required in layout");
325+
326+
ArrayRef<int64_t> wgShape = tdescTy.getShape();
327+
SmallVector<int64_t> sgShape;
328+
int count;
329+
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
330+
331+
// Get the subgroup ID
332+
Value linearSgId =
333+
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
334+
335+
int64_t startOfRange = -1, endOfRange = -1;
336+
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
337+
338+
if (sgIdRangeSpecified) {
339+
int64_t sgCount = endOfRange - startOfRange;
340+
if (computeProduct(sgLayout) != sgCount)
341+
return rewriter.notifyMatchFailure(
342+
op, "sg_layout size must match the sg_id_range");
343+
Value startOfRangeVal =
344+
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
345+
linearSgId =
346+
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
347+
}
348+
349+
auto maybeTdescOffsets =
350+
layout.getOffsets(rewriter, loc, linearSgId, wgShape);
351+
if (failed(maybeTdescOffsets))
352+
return failure();
353+
354+
SmallVector<OpFoldResult> oldOffsets;
355+
if (auto constOffsets = op.getConstOffsetsAttr()) {
356+
for (auto attr : constOffsets.asArrayRef())
357+
oldOffsets.push_back(rewriter.getIndexAttr(attr));
358+
}
359+
for (auto v : op.getOffsets())
360+
oldOffsets.push_back(v);
361+
362+
// Delegate to the operation-specific creation function
363+
return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
364+
rewriter, op);
365+
}
366+
367+
// Usage for LoadNdOp
368+
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
369+
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
370+
LogicalResult matchAndRewrite(
371+
xegpu::LoadNdOp op,
372+
typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
373+
ConversionPatternRewriter &rewriter) const override {
374+
return distributeNdOpWithOffset(
375+
op, adaptor, rewriter,
376+
[](Location loc, SmallVector<int64_t> &sgShape,
377+
ArrayRef<SmallVector<Value>> tdescOffsetsList,
378+
SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
379+
ConversionPatternRewriter &rewriter,
380+
xegpu::LoadNdOp &op) -> LogicalResult {
381+
SmallVector<Value> newLoadOps;
382+
for (auto [tdescOffsets, tdesc] :
383+
llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
384+
SmallVector<OpFoldResult> newOffsets;
385+
size_t rank = tdescOffsets.size();
386+
for (size_t i = 0; i < rank; i++) {
387+
size_t idx = oldOffsets.size() - rank + i;
388+
Value add = rewriter.createOrFold<index::AddOp>(
389+
loc, tdescOffsets[i],
390+
getValueOrCreateConstantIndexOp(rewriter, loc,
391+
oldOffsets[idx]));
392+
newOffsets.push_back(add);
393+
}
394+
VectorType newResTy = VectorType::get(
395+
sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType())
396+
.getElementType());
397+
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
398+
loc, newResTy, tdesc, newOffsets,
399+
/*packed=*/nullptr,
400+
/*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
401+
op.getL3HintAttr());
402+
newLoadOps.push_back(newLoadOp);
403+
}
404+
rewriter.replaceOpWithMultiple(op, {newLoadOps});
405+
return success();
406+
});
407+
}
408+
};
409+
413410
// This pattern transforms the StoreNdOp with explicit offsets to store
414411
// subgroup data.
415412
struct WgToSgStoreNdOpWithOffset
416413
: public OpConversionPattern<xegpu::StoreNdOp> {
417414
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
415+
LogicalResult matchAndRewrite(
416+
xegpu::StoreNdOp op,
417+
typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
418+
ConversionPatternRewriter &rewriter) const override {
419+
return distributeNdOpWithOffset(
420+
op, adaptor, rewriter,
421+
[](Location loc, SmallVector<int64_t> &sgShape,
422+
ArrayRef<SmallVector<Value>> tdescOffsetsList,
423+
SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
424+
ConversionPatternRewriter &rewriter,
425+
xegpu::StoreNdOp &op) -> LogicalResult {
426+
for (auto [tdescOffsets, tdesc, value] :
427+
llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
428+
adaptor.getValue())) {
429+
SmallVector<OpFoldResult> newOffsets;
430+
size_t rank = tdescOffsets.size();
431+
for (size_t i = 0; i < rank; i++) {
432+
size_t idx = oldOffsets.size() - rank + i;
433+
Value add = rewriter.createOrFold<index::AddOp>(
434+
loc, tdescOffsets[i],
435+
getValueOrCreateConstantIndexOp(rewriter, loc,
436+
oldOffsets[idx]));
437+
newOffsets.push_back(add);
438+
}
439+
rewriter.create<xegpu::StoreNdOp>(
440+
loc, value, tdesc, newOffsets, op.getL1HintAttr(),
441+
op.getL2HintAttr(), op.getL3HintAttr());
442+
}
443+
rewriter.eraseOp(op);
444+
return success();
445+
});
446+
}
447+
};
418448

419-
LogicalResult
420-
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
421-
ConversionPatternRewriter &rewriter) const override {
422-
423-
auto infoOpt = extractSgOffsetInfo(op, rewriter);
424-
if (!infoOpt)
425-
return failure();
426-
const auto &info = *infoOpt;
427-
428-
auto sgOffsets =
429-
computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
430-
info.tdescTy.getShape(), info.oldOffsets);
431-
if (sgOffsets.empty())
432-
return failure();
433-
434-
auto tdescRange = adaptor.getTensorDesc();
435-
auto valueRange = adaptor.getValue();
436-
for (auto it : llvm::zip(sgOffsets, tdescRange, valueRange)) {
437-
rewriter.create<xegpu::StoreNdOp>(
438-
info.loc, std::get<2>(it), std::get<1>(it), std::get<0>(it),
439-
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
440-
}
441-
rewriter.eraseOp(op);
442-
return success();
449+
// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
450+
// subgroup data.
451+
struct WgToSgPrefetchNdOpWithOffset
452+
: public OpConversionPattern<xegpu::PrefetchNdOp> {
453+
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
454+
LogicalResult matchAndRewrite(
455+
xegpu::PrefetchNdOp op,
456+
typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
457+
adaptor,
458+
ConversionPatternRewriter &rewriter) const override {
459+
return distributeNdOpWithOffset(
460+
op, adaptor, rewriter,
461+
[](Location loc, SmallVector<int64_t> &sgShape,
462+
ArrayRef<SmallVector<Value>> tdescOffsetsList,
463+
SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
464+
ConversionPatternRewriter &rewriter,
465+
xegpu::PrefetchNdOp &op) -> LogicalResult {
466+
for (auto [tdescOffsets, tdesc] :
467+
llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
468+
SmallVector<OpFoldResult> newOffsets;
469+
size_t rank = tdescOffsets.size();
470+
for (size_t i = 0; i < rank; i++) {
471+
size_t idx = oldOffsets.size() - rank + i;
472+
Value add = rewriter.createOrFold<index::AddOp>(
473+
loc, tdescOffsets[i],
474+
getValueOrCreateConstantIndexOp(rewriter, loc,
475+
oldOffsets[idx]));
476+
newOffsets.push_back(add);
477+
}
478+
rewriter.create<xegpu::PrefetchNdOp>(
479+
loc, tdesc, newOffsets, op.getL1HintAttr(), op.getL2HintAttr(),
480+
op.getL3HintAttr());
481+
}
482+
rewriter.eraseOp(op);
483+
return success();
484+
});
443485
}
444486
};
445487

@@ -529,38 +571,6 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
529571
}
530572
};
531573

532-
// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
533-
// subgroup data.
534-
struct WgToSgPrefetchNdOpWithOffset
535-
: public OpConversionPattern<xegpu::PrefetchNdOp> {
536-
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
537-
538-
LogicalResult
539-
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
540-
ConversionPatternRewriter &rewriter) const override {
541-
542-
auto infoOpt = extractSgOffsetInfo(op, rewriter);
543-
if (!infoOpt)
544-
return failure();
545-
const auto &info = *infoOpt;
546-
547-
auto sgOffsets =
548-
computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
549-
info.tdescTy.getShape(), info.oldOffsets);
550-
if (sgOffsets.empty())
551-
return failure();
552-
553-
auto tdescRange = adaptor.getTensorDesc();
554-
for (auto it : llvm::zip(sgOffsets, tdescRange)) {
555-
rewriter.create<xegpu::PrefetchNdOp>(
556-
info.loc, std::get<1>(it), std::get<0>(it), op.getL1HintAttr(),
557-
op.getL2HintAttr(), op.getL3HintAttr());
558-
}
559-
rewriter.eraseOp(op);
560-
return success();
561-
}
562-
};
563-
564574
/// This pattern transforms vector.broadcast ops to work at subgroup level.
565575
struct WgToSgVectorBroadcastOp
566576
: public OpConversionPattern<vector::BroadcastOp> {

0 commit comments

Comments
 (0)