Commit 65b5dbd

refactor ConvertLayoutPattern for wg to sg.
1 parent c416cec commit 65b5dbd

5 files changed: +70 -160 lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 7 additions & 7 deletions
@@ -313,13 +313,13 @@ LogicalResult TensorDescType::verify(
   if (rank != 1 && rank != 2)
     return emitError() << "expected 1D or 2D tensor";

-  // auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
-  // if (blockAttr) {
-  //   MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
-  //   if (rank == 2 && memorySpaceAttr &&
-  //       memorySpaceAttr.getValue() == MemorySpace::SLM)
-  //     return emitError() << "SLM is not supported for 2D block tensor";
-  // }
+  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+  if (blockAttr) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }

   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
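
For illustration, IR along these lines would now be rejected by the re-enabled check. This is a minimal sketch, assuming the usual #xegpu.block_tdesc_attr spelling for the block descriptor encoding; the value names and the SLM memref are hypothetical and not part of this commit:

// Hypothetical IR: a 2D block tensor_desc placed in SLM (memory space 3) now
// fails TensorDescType::verify with "SLM is not supported for 2D block tensor".
%desc = xegpu.create_nd_tdesc %slm_buf[0, 0] : memref<16x16xf32, 3>
    -> !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>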

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 6 additions & 5 deletions
@@ -616,7 +616,7 @@ LogicalResult ConvertLayoutOp::verify() {
   if (!resLayout)
     return emitOpError("expected target layout.");

-  // both srcMap and resMap should be WgLayout or SgLayout at the same time.
+  // both input and target layouts should be WgLayout or SgLayout at the same time.
   if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
       (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
     return emitOpError("expected input layout and target layout be WgLayout or "
@@ -644,10 +644,11 @@ struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
   using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getInputLayout() != op.getTargetLayout())
-      return failure();
-    rewriter.replaceOp(op, op.getSource());
-    return success();
+    if (op.getInputLayout() == op.getTargetLayout()) {
+      rewriter.replaceOp(op, op.getSource());
+      return success();
+    }
+    return failure();
   }
 };
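
The rewritten FoldConvertLayoutOp keeps the same behavior: a convert_layout whose input and target layouts are identical is folded away by forwarding its source. A minimal sketch of IR the pattern erases, with illustrative values and layouts that are not taken from this commit:

// input_layout and target_layout are identical, so the fold pattern replaces
// all uses of %r with %v and erases the op.
%r = xegpu.convert_layout %v <{input_layout = #xegpu.layout<inst_data = [16, 16]>,
                               target_layout = #xegpu.layout<inst_data = [16, 16]>}> : vector<16x16xf32>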

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 46 additions & 147 deletions
@@ -57,39 +57,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }

-// Calculate offset for each subgroup
-static SmallVector<OpFoldResult>
-calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
-                       const SmallVector<OpFoldResult> &originalOffsets,
-                       const SmallVector<Value> &localOffset,
-                       const SmallVector<int64_t> &distUnitBaseAddr,
-                       const SmallVector<int64_t> &distUnitShape) {
-  assert(localOffset.size() == distUnitBaseAddr.size() &&
-         "localOffset and distUnitBaseAddr must have the same rank");
-
-  SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
-                                          originalOffsets.end());
-  size_t rank = localOffset.size();
-  for (size_t i = 0; i < rank; ++i) {
-    size_t dimIdx = originalOffsets.size() - rank + i;
-    Value constOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
-    Value offset =
-        rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
-    Value modValue =
-        rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
-    Value offsetMod =
-        rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
-    Value origOffset =
-        getValueOrCreateConstantIndexOp(rewriter, loc, originalOffsets[dimIdx]);
-    Value globalOffset =
-        rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
-    globalOffsets[dimIdx] = globalOffset;
-  }
-
-  return globalOffsets;
-}
-
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -138,6 +105,39 @@ calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

+  // Calculate offset for each subgroup
+  static SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<OpFoldResult> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr,
+                         const SmallVector<int64_t> &distUnitShape) {
+    assert(localOffset.size() == distUnitBaseAddr.size() &&
+           "localOffset and distUnitBaseAddr must have the same rank");
+
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    size_t rank = localOffset.size();
+    for (size_t i = 0; i < rank; ++i) {
+      size_t dimIdx = originalOffsets.size() - rank + i;
+      Value constOffset =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+      Value offset =
+          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value modValue =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
+      Value offsetMod =
+          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
+      Value origOffset =
+          getValueOrCreateConstantIndexOp(rewriter, loc, originalOffsets[dimIdx]);
+      Value globalOffset =
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
+      globalOffsets[dimIdx] = globalOffset;
+    }
+
+    return globalOffsets;
+  }
+
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -390,21 +390,6 @@ struct WgToSgElementwiseOp : public ConversionPattern {
   }
 };

-// based on the size of the given vector type
-static TypedValue<MemRefType>
-allocateSLMBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                  VectorType type) {
-  int64_t bits = type.getElementType().getIntOrFloatBitWidth();
-  int64_t slmSizeInBytes = type.getNumElements() * bits / 8;
-  auto slmTy = MemRefType::get(slmSizeInBytes, rewriter.getI8Type(), {}, 3);
-  auto slm = rewriter.create<memref::AllocOp>(loc, slmTy);
-  auto viewTy = MemRefType::get(type.getShape(), type.getElementType(), {}, 3);
-  auto view = rewriter.create<memref::ViewOp>(
-      loc, viewTy, slm, rewriter.create<arith::ConstantIndexOp>(loc, 0),
-      ValueRange());
-  return view;
-}
-
 struct WgToSgConvertLayoutOp
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
   using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
@@ -418,115 +403,29 @@ struct WgToSgConvertLayoutOp
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");

-    // initialize values with the source values
-    SmallVector<Value> values(adaptor.getSource());
-
-    Location loc = op.getLoc();
-    MLIRContext *ctx = op.getContext();
-    VectorType type = op.getResult().getType();
-    ArrayRef<int64_t> shape = type.getShape();
-
     DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
     DenseI32ArrayAttr inputSgData = input.getSgData();
     DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
     DenseI32ArrayAttr targetSgData = target.getSgData();

-    // we only need SLM support when input and target layouts are different
-    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData) {
-      values.clear();
-      rewriter.setInsertionPoint(op);
-      TypedValue<MemRefType> slmBuffer = allocateSLMBuffer(rewriter, loc, type);
-
-      auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(
-          loc, rewriter.getIndexType(), nullptr);
-
-      { // store to slm buffer
-        SmallVector<int64_t> sgLayout =
-            llvm::to_vector_of<int64_t>(input.getSgLayout().asArrayRef());
-        SmallVector<int64_t> sgShape = getSgShapeAndCount(shape, input).first;
-        auto delinearized = affine::delinearizeIndex(
-            rewriter, loc, linearSgId, getAsIndexOpFoldResult(ctx, sgLayout));
-        if (failed(delinearized))
-          return rewriter.notifyMatchFailure(op, "Failed to delinearize sgId");
-        SmallVector<Value> sgIds = *delinearized;
-
-        SmallVector<int64_t> distUnitShape(sgLayout.size());
-        SmallVector<Value> localOffset(sgLayout.size());
-        for (size_t i = 0; i < sgLayout.size(); i++) {
-          distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], shape[i]);
-          localOffset[i] = rewriter.createOrFold<index::MulOp>(
-              loc, sgIds[i],
-              rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]));
-        }
-
-        auto tdescTy = xegpu::TensorDescType::get(
-            sgShape, type.getElementType(), 1, false, xegpu::MemorySpace::SLM,
-            input.dropSgLayoutAndData());
-
-        SmallVector<OpFoldResult> zeros = getAsIndexOpFoldResult(
-            ctx, SmallVector<int64_t>(sgLayout.size(), 0));
-        for (auto [data, baseOffsets] :
-             llvm::zip_equal(adaptor.getSource(),
-                             StaticTileOffsetRange(shape, distUnitShape))) {
-          SmallVector<OpFoldResult> offsets = calculateGlobalOffsets(
-              rewriter, loc, zeros, localOffset, baseOffsets, distUnitShape);
-          auto tdesc = rewriter.create<xegpu::CreateNdDescOp>(
-              loc, tdescTy, slmBuffer, offsets);
-          rewriter.create<xegpu::StoreNdOp>(loc, data, tdesc, nullptr, nullptr,
-                                            nullptr);
-        }
-      }
-
-      rewriter.create<gpu::BarrierOp>(loc);
-
-      { // load from SLM
-        SmallVector<int64_t> sgLayout =
-            llvm::to_vector_of<int64_t>(target.getSgLayout().asArrayRef());
-        SmallVector<int64_t> sgShape = getSgShapeAndCount(shape, target).first;
-        auto delinearized = affine::delinearizeIndex(
-            rewriter, loc, linearSgId, getAsIndexOpFoldResult(ctx, sgLayout));
-        if (failed(delinearized))
-          return rewriter.notifyMatchFailure(op, "Failed to delinearize sgId");
-        SmallVector<Value> sgIds = *delinearized;
-
-        SmallVector<int64_t> distUnitShape(sgLayout.size());
-        SmallVector<Value> localOffset(sgLayout.size());
-        for (size_t i = 0; i < sgLayout.size(); i++) {
-          distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], shape[i]);
-          localOffset[i] = rewriter.createOrFold<index::MulOp>(
-              loc, sgIds[i],
-              rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]));
-        }
-
-        auto tdescTy = xegpu::TensorDescType::get(
-            sgShape, type.getElementType(), 1, false, xegpu::MemorySpace::SLM,
-            target.dropSgLayoutAndData());
-        auto valueTy = VectorType::get(sgShape, type.getElementType());
-
-        SmallVector<OpFoldResult> zeros = getAsIndexOpFoldResult(
-            ctx, SmallVector<int64_t>(sgLayout.size(), 0));
-        for (auto baseOffsets : StaticTileOffsetRange(shape, distUnitShape)) {
-          SmallVector<OpFoldResult> offsets = calculateGlobalOffsets(
-              rewriter, loc, zeros, localOffset, baseOffsets, distUnitShape);
-          auto tdesc = rewriter.create<xegpu::CreateNdDescOp>(
-              loc, tdescTy, slmBuffer, offsets);
-          auto newOp = rewriter.create<xegpu::LoadNdOp>(
-              loc, TypeRange({valueTy}), ValueRange({tdesc}));
-          values.push_back(newOp);
-        }
-      }
-    }
+    // TODO: currently we only support for optimal case, where input and
+    // output has the same sg_layout and sg_data, so SLM is not involved.
+    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData)
+      return failure();

     input = input.dropSgLayoutAndData();
     target = target.dropSgLayoutAndData();

-    SmallVector<Value> newOps;
-    for (auto src : values) {
-      auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
-          op.getLoc(), src.getType(), src, input, target);
-      newOps.push_back(newOp);
+    SmallVector<Value> newOps(adaptor.getSource());
+
+    if (input && target) {
+      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+        auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
+            op.getLoc(), src.getType(), src, input, target);
+        newOps[i] = newOp;
+      }
     }
-    rewriter.replaceOpWithMultiple(op, newOps);
+    rewriter.replaceOpWithMultiple(op, {newOps});
     return success();
   }
 };
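
In the supported case (identical sg_layout and sg_data on both sides), the simplified WgToSgConvertLayoutOp just re-emits convert_layout once per distributed source fragment, with the subgroup fields dropped from both layouts. Roughly, matching the new test added at the end of this commit; the value names here are illustrative only:

// Workgroup level, before distribution:
%2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>,
                               target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>

// Subgroup level, after WgToSgConvertLayoutOp (one op per 16x16 fragment;
// sg_layout/sg_data are dropped, only inst_data remains):
%s = xegpu.convert_layout %frag <{input_layout = #xegpu.layout<inst_data = [16, 16]>,
                                  target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>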

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
124124
Operation *defOp = result.getDefiningOp();
125125
assert(defOp && "result must have a defining op");
126126

127-
// For ConvertLayoutOp, the layout is stored in the tensor descriptor
127+
// For ConvertLayoutOp, the layout is stored in the targetLayoutAttr
128128
if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
129129
return convertOp.getTargetLayoutAttr();
130130

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 10 additions & 0 deletions
@@ -198,4 +198,14 @@ gpu.module @test_round_robin_assignment {
     gpu.return
   }

+  gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
+    //CHECK-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
+    //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
+    %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
+    gpu.return
+  }
+
 }
