Skip to content

Commit 1c06920

Browse files
committed
fix
1 parent f6cd50a commit 1c06920

File tree

1 file changed

+116
-115
lines changed

1 file changed

+116
-115
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp

Lines changed: 116 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arith/Utils/Utils.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15+
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
1516
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1617
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1718
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -66,12 +67,12 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
6667
// /// -> !xegpu.tensor_desc<4x1xf32>
6768
// ///
6869
// /// ```
69-
// struct WarpOpTensorDescOp final
70-
// : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
71-
// using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
72-
// LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
73-
// PatternRewriter &rewriter) const override;
74-
// };
70+
struct WarpOpTensorDescOp final
71+
: public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
72+
using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
73+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
74+
PatternRewriter &rewriter) const override;
75+
};
7576

7677
// /// Sink a store_nd feeding into vector.yield op for the enclosing
7778
// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
@@ -159,32 +160,32 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
159160
// return newVectorType;
160161
// }
161162

162-
// FailureOr<xegpu::TensorDescType>
163-
// getDistributedTensorDescType(xegpu::TensorDescType originalT,
164-
// xegpu::SGMapAttr sgMap,
165-
// xegpu::MemorySpace memSpace) {
166-
// llvm::SmallVector<int64_t, 2> distributedShape;
167-
// auto layout = sgMap.getWiLayout();
168-
// auto shape = originalT.getShape();
169-
// for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
170-
// if (!divisible(APInt(64, o), APInt(64, l)))
171-
// return failure();
172-
// distributedShape.push_back(o / l);
173-
// }
174-
// xegpu::TensorDescType distributedDescType;
175-
// if (originalT.isScattered()) {
176-
177-
// distributedDescType = xegpu::TensorDescType::get(
178-
// distributedShape, originalT.getElementType(), originalT.getChunkSize(),
179-
// originalT.getMemorySpace(), originalT.getSGMapAttr());
180-
// } else {
181-
// distributedDescType = xegpu::TensorDescType::get(
182-
// distributedShape, originalT.getElementType(),
183-
// originalT.getBoundaryCheck(), originalT.getArrayLength(),
184-
// originalT.getMemorySpace(), originalT.getSGMapAttr());
185-
// }
186-
// return distributedDescType;
187-
// }
163+
FailureOr<xegpu::TensorDescType>
164+
getDistributedTensorDescType(xegpu::TensorDescType originalT,
165+
xegpu::SGMapAttr sgMap,
166+
xegpu::MemorySpace memSpace) {
167+
llvm::SmallVector<int64_t, 2> distributedShape;
168+
auto layout = sgMap.getWiLayout();
169+
auto shape = originalT.getShape();
170+
for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
171+
if (!divisible(APInt(64, o), APInt(64, l)))
172+
return failure();
173+
distributedShape.push_back(o / l);
174+
}
175+
xegpu::TensorDescType distributedDescType;
176+
if (originalT.isScattered()) {
177+
178+
distributedDescType = xegpu::TensorDescType::get(
179+
distributedShape, originalT.getElementType(), originalT.getChunkSize(),
180+
originalT.getMemorySpace(), originalT.getSGMapAttr());
181+
} else {
182+
distributedDescType = xegpu::TensorDescType::get(
183+
distributedShape, originalT.getElementType(),
184+
originalT.getBoundaryCheck(), originalT.getArrayLength(),
185+
originalT.getMemorySpace(), originalT.getSGMapAttr());
186+
}
187+
return distributedDescType;
188+
}
188189
} // namespace
189190

190191
// LogicalResult
@@ -303,91 +304,91 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
303304
// return success();
304305
// }
305306

306-
// LogicalResult
307-
// WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
308-
// PatternRewriter &rewriter) const {
309-
// OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
310-
// return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
311-
// });
312-
313-
// if (!operand)
314-
// return rewriter.notifyMatchFailure(
315-
// warpOp, "warp result is not a xegpu::CreateNdDesc op");
316-
// auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
317-
// assert(descOp && "desc op must be not null");
318-
// unsigned operandIdx = operand->getOperandNumber();
319-
320-
// // TODO: is memref uniform in the region
321-
// rewriter.setInsertionPoint(warpOp);
322-
// auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
323-
// assert(srcTypedVal && "source value must be not null");
324-
325-
// auto descOffsets = descOp.getMixedOffsets();
326-
// if (descOffsets.size() != 2)
327-
// return rewriter.notifyMatchFailure(descOp,
328-
// "offsets size is expected to be 2");
329-
330-
// xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
331-
// if (!sgMap)
332-
// return rewriter.notifyMatchFailure(
333-
// descOp, "the tensor descriptor lacks sg_map attribute");
334-
335-
// auto layout = sgMap.getWiLayout();
336-
337-
// // Calculate the offset within tensor descriptor for the current lane_id. The
338-
// // access to proper element for a work item is done through a lane-specific
339-
// // subview (tdesc offsets are used as base, lane shift is added on top).
340-
// auto laneid = warpOp.getLaneid();
341-
// auto xDim =
342-
// rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
343-
// auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
344-
// auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
345-
346-
// auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
347-
// descOffsets[0]);
348-
// auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
349-
// descOffsets[1]);
350-
// auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
351-
// auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
352-
353-
// auto distributedDescTypeOrFailure = getDistributedTensorDescType(
354-
// descOp.getType(), sgMap, descOp.getType().getMemorySpace());
355-
// if (failed(distributedDescTypeOrFailure))
356-
// return rewriter.notifyMatchFailure(descOp,
357-
// "Failed to distribute the desc type");
358-
// xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
359-
// auto distributedShape = newTDescType.getShape();
360-
// // use the base memref strides
361-
// SmallVector<OpFoldResult> overwriteStrides =
362-
// getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
363-
// SmallVector<OpFoldResult> overwriteSizes =
364-
// getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
365-
366-
// SmallVector<size_t> newRetIndices;
367-
// vector::WarpExecuteOnLane0Op newWarpOp =
368-
// moveRegionToNewWarpOpAndAppendReturns(
369-
// rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
370-
// newRetIndices);
371-
372-
// rewriter.setInsertionPointAfter(newWarpOp);
373-
// auto subview = rewriter.create<memref::SubViewOp>(
374-
// newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
375-
// overwriteSizes, overwriteStrides);
376-
// subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
377-
378-
// auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
379-
// auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
380-
// newWarpOp.getLoc(), newTDescType, subview,
381-
// getAsOpFoldResult({zero, zero}));
382-
383-
// Value distributedVal = newWarpOp.getResult(operandIdx);
384-
// rewriter.replaceAllUsesWith(distributedVal, newDescOp);
385-
386-
// return success();
387-
// }
307+
LogicalResult
308+
WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
309+
PatternRewriter &rewriter) const {
310+
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
311+
return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
312+
});
313+
314+
if (!operand)
315+
return rewriter.notifyMatchFailure(
316+
warpOp, "warp result is not a xegpu::CreateNdDesc op");
317+
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
318+
assert(descOp && "desc op must be not null");
319+
unsigned operandIdx = operand->getOperandNumber();
320+
321+
// TODO: is memref uniform in the region
322+
rewriter.setInsertionPoint(warpOp);
323+
auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
324+
assert(srcTypedVal && "source value must be not null");
325+
326+
auto descOffsets = descOp.getMixedOffsets();
327+
if (descOffsets.size() != 2)
328+
return rewriter.notifyMatchFailure(descOp,
329+
"offsets size is expected to be 2");
330+
331+
xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
332+
if (!sgMap)
333+
return rewriter.notifyMatchFailure(
334+
descOp, "the tensor descriptor lacks sg_map attribute");
335+
336+
auto layout = sgMap.getWiLayout();
337+
338+
// Calculate the offset within tensor descriptor for the current lane_id. The
339+
// access to proper element for a work item is done through a lane-specific
340+
// subview (tdesc offsets are used as base, lane shift is added on top).
341+
auto laneid = warpOp.getLaneid();
342+
auto xDim =
343+
rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
344+
auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
345+
auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
346+
347+
auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
348+
descOffsets[0]);
349+
auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
350+
descOffsets[1]);
351+
auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
352+
auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
353+
354+
auto distributedDescTypeOrFailure = getDistributedTensorDescType(
355+
descOp.getType(), sgMap, descOp.getType().getMemorySpace());
356+
if (failed(distributedDescTypeOrFailure))
357+
return rewriter.notifyMatchFailure(descOp,
358+
"Failed to distribute the desc type");
359+
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
360+
auto distributedShape = newTDescType.getShape();
361+
// use the base memref strides
362+
SmallVector<OpFoldResult> overwriteStrides =
363+
getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
364+
SmallVector<OpFoldResult> overwriteSizes =
365+
getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
366+
367+
SmallVector<size_t> newRetIndices;
368+
gpu::WarpExecuteOnLane0Op newWarpOp =
369+
moveRegionToNewWarpOpAndAppendReturns(
370+
rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
371+
newRetIndices);
372+
373+
rewriter.setInsertionPointAfter(newWarpOp);
374+
auto subview = rewriter.create<memref::SubViewOp>(
375+
newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
376+
overwriteSizes, overwriteStrides);
377+
subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
378+
379+
auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
380+
auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
381+
newWarpOp.getLoc(), newTDescType, subview,
382+
getAsOpFoldResult({zero, zero}));
383+
384+
Value distributedVal = newWarpOp.getResult(operandIdx);
385+
rewriter.replaceAllUsesWith(distributedVal, newDescOp);
386+
387+
return success();
388+
}
388389

389390
/// Register the XeGPU distribution patterns. Only the tensor-descriptor
/// pattern is enabled; the store/load patterns below are still ported
/// (commented out) — see the disabled definitions earlier in this file.
void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WarpOpTensorDescOp>(patterns.getContext());
  // patterns.add<WarpOpStoreNd>(patterns.getContext());
  // patterns.add<WarpOpLoadNd>(patterns.getContext());
}

0 commit comments

Comments
 (0)