
Commit 491625d

fix
1 parent 9888c84 commit 491625d

File tree

1 file changed: +161 −158 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp

Lines changed: 161 additions & 158 deletions
@@ -33,11 +33,15 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
 
 // /// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by the
-// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// The warp op will still contain the original op that will not be used by
+// the
+// /// yield op (and should be cleaned up later with dce). The yield op will
+// bypass
 // /// the create_nd_tdesc's arguments.
-// /// The rewrite will create a subview of the size used by a single work item and
-// /// appropriate offset. The distributed create_nd_tdesc points into the subview
+// /// The rewrite will create a subview of the size used by a single work item
+// and
+// /// appropriate offset. The distributed create_nd_tdesc points into the
+// subview
 // /// without offset. The tensor descriptor types is distributed according to
 // /// sg_map attribute.
 // ///
@@ -75,8 +79,10 @@ struct WarpOpTensorDescOp final
 };
 
 // /// Sink a store_nd feeding into vector.yield op for the enclosing
-// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
-// /// through the warp op interface they would be propagated as returned values.
+// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are
+// passed
+// /// through the warp op interface they would be propagated as returned
+// values.
 // /// Both the stored vector type and tensor descriptor types are distributed
 // /// according to sg_map attribute.
 // ///
@@ -97,20 +103,23 @@ struct WarpOpTensorDescOp final
 // /// ...
 // /// vector.yield
 // /// }
-// /// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+// /// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>,
+// !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpStoreNd final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
+struct WarpOpStoreNd final
+    : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
 
 // /// Clone a load_nd feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by the
-// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// The warp op will still contain the original op that will not be used by
+// the
+// /// yield op (and should be cleaned up later with dce). The yield op will
+// bypass
 // /// the load's arguments.
 // /// Both the loaded vector type and tensor descriptor types are distributed
 // /// according to sg_map attribute.
@@ -137,28 +146,27 @@ struct WarpOpTensorDescOp final
 // /// xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpLoadNd final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
-
-// FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
-//                                                xegpu::SGMapAttr sgMap) {
-//   llvm::SmallVector<int64_t, 2> distributedShape;
-//   auto layout = sgMap.getWiLayout();
-//   auto shape = originalT.getShape();
-//   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
-//     if (!divisible(APInt(64, o), APInt(64, l)))
-//       return failure();
-//     distributedShape.push_back(o / l);
-//   }
-//   auto newVectorType =
-//       VectorType::get(distributedShape, originalT.getElementType(),
-//                       originalT.getScalableDims());
-//   return newVectorType;
-// }
+struct WarpOpLoadNd final : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}
 
 FailureOr<xegpu::TensorDescType>
 getDistributedTensorDescType(xegpu::TensorDescType originalT,
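Note on the re-enabled helper: getDistributedVectorType divides each dimension of the original vector shape by the matching wi_layout entry of the sg_map and fails when a dimension is not evenly divisible. Below is a minimal, standalone sketch of that arithmetic; the 4x16 shape and [1, 16] layout are illustrative assumptions, not values taken from this commit.

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Standalone illustration of the per-dimension division performed by
// getDistributedVectorType: shape[i] / layout[i], rejecting shapes that are
// not divisible by the work-item layout. Shape and layout values are
// hypothetical.
static std::optional<std::vector<int64_t>>
distributeShape(const std::vector<int64_t> &shape,
                const std::vector<int64_t> &layout) {
  std::vector<int64_t> distributed;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] % layout[i] != 0)
      return std::nullopt; // mirrors the divisible() check -> failure()
    distributed.push_back(shape[i] / layout[i]);
  }
  return distributed;
}

int main() {
  // e.g. a 4x16 vector distributed over an assumed wi_layout of [1, 16].
  auto result = distributeShape({4, 16}, {1, 16});
  if (result)
    std::printf("distributed shape: %lldx%lld\n", (long long)(*result)[0],
                (long long)(*result)[1]);
  return 0;
}

With those assumed values the per-work-item type would be vector<4x1xf32>, which matches the shapes used in the doc-comment examples above.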
@@ -192,121 +200,117 @@ getDistributedTensorDescType(xegpu::TensorDescType originalT,
 }
 } // namespace
 
-// LogicalResult
-// WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                PatternRewriter &rewriter) const {
-//   auto yield = cast<vector::YieldOp>(
-//       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-//   Operation *lastNode = yield->getPrevNode();
-//   auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
-//   if (!storeOp)
-//     return failure();
-
-//   auto origType = storeOp.getTensorDescType();
-//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         storeOp, "the source tensor descriptor lacks sg_map attribute");
-
-//   if (storeOp.getTensorDescType().getShape().size() != 2)
-//     return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
-//   DBGS() << "Matched store_nd: " << storeOp << "\n";
-
-//   auto distributedTypeOrFailure =
-//       getDistributedVectorType(storeOp.getValueType(), sgMap);
-//   if (failed(distributedTypeOrFailure))
-//     return rewriter.notifyMatchFailure(storeOp,
-//                                        "Failed to distribute the type");
-//   VectorType newVectorType = distributedTypeOrFailure.value();
-
-//   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
-//       storeOp.getTensorDescType(), sgMap,
-//       storeOp.getTensorDescType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(storeOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp,
-//           ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
-//           TypeRange{newTDescType, newVectorType}, newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-//   auto newStoreOp =
-//       cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
-//   rewriter.eraseOp(storeOp);
-//   newStoreOp.getTensorDescMutable().assign(
-//       newWarpOp.getResult(newRetIndices[0]));
-//   newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
-
-//   return success();
-// }
-
-// LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                             PatternRewriter &rewriter) const {
-//   OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
-//     return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
-//   });
-
-//   if (!operand)
-//     return rewriter.notifyMatchFailure(warpOp,
-//                                        "warp result is not a xegpu::LoadNd op");
-
-//   auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
-
-//   if (loadOp.getPacked())
-//     return rewriter.notifyMatchFailure(
-//         loadOp, "Packed load distribution not supported");
-
-//   xegpu::TensorDescType origType = loadOp.getTensorDescType();
-//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         loadOp, "the source tensor descriptor lacks sg_map attribute");
-
-//   auto origShape = origType.getShape();
-//   if (origShape.size() != 2)
-//     return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
-
-//   auto distributedTypeOrFailure =
-//       getDistributedVectorType(loadOp.getType(), sgMap);
-//   if (failed(distributedTypeOrFailure))
-//     return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
-//   VectorType newVectorType = distributedTypeOrFailure.value();
-
-//   auto distributedDescTypeOrFailure =
-//       getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
-//                                    loadOp.getTensorDescType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(loadOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-
-//   unsigned operandIdx = operand->getOperandNumber();
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
-//           newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-
-//   auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-//       loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
-//       loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
-//       loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
-
-//   newLoadOp.getTensorDescMutable().assign(
-//       newWarpOp.getResult(newRetIndices[0]));
-//   Value distributedVal = newWarpOp.getResult(operandIdx);
-//   rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
-
-//   return success();
-// }
+LogicalResult WarpOpStoreNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                             PatternRewriter &rewriter) const {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  auto origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  DBGS() << "Matched store_nd: " << storeOp << "\n";
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+      TypeRange{newTDescType, newVectorType}, newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+      newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}
 
 LogicalResult
 WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
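WarpOpLoadNd locates its candidate through the existing getWarpResult helper rather than scanning the terminator directly, as WarpOpStoreNd does. The sketch below is an assumption-based illustration of what such a yield-operand matcher typically looks like (the name findYieldOperand is hypothetical; this is not the helper defined in this file):

// Hypothetical illustration only: return the yield operand of the warp op
// whose defining op satisfies `fn`, mirroring how getWarpResult is used by
// WarpOpLoadNd above. Not the actual implementation from this commit.
static OpOperand *findYieldOperand(gpu::WarpExecuteOnLane0Op warpOp,
                                   llvm::function_ref<bool(Operation *)> fn) {
  auto yield = cast<gpu::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Operation *def = yieldOperand.get().getDefiningOp();
    if (def && fn(def))
      return &yieldOperand;
  }
  return nullptr;
}

The key detail is that the matched value must feed the yield: the returned OpOperand's index (operandIdx above) is what identifies the warp-op result that gets replaced by the distributed load.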
@@ -369,10 +373,9 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
       getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
 
   SmallVector<size_t> newRetIndices;
-  gpu::WarpExecuteOnLane0Op newWarpOp =
-      moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
-          newRetIndices);
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+      newRetIndices);
 
   rewriter.setInsertionPointAfter(newWarpOp);
   auto subview = rewriter.create<memref::SubViewOp>(
@@ -393,6 +396,6 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
 
 void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WarpOpTensorDescOp>(patterns.getContext());
-  // patterns.add<WarpOpStoreNd>(patterns.getContext());
-  // patterns.add<WarpOpLoadNd>(patterns.getContext());
+  patterns.add<WarpOpStoreNd>(patterns.getContext());
+  patterns.add<WarpOpLoadNd>(patterns.getContext());
 }
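For context, nothing beyond the usual greedy rewrite driver is needed to exercise the newly registered patterns. A hedged sketch of how a test pass could wire this up (the pass name and registration boilerplate are assumptions for illustration, not part of this commit):

// Assumed test-pass wiring, for illustration only.
struct TestXeGPUDistribution
    : public PassWrapper<TestXeGPUDistribution, OperationPass<>> {
  StringRef getArgument() const final { return "test-xegpu-distribute"; }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // Registers WarpOpTensorDescOp, WarpOpStoreNd and WarpOpLoadNd.
    xegpu::populateXeGPUDistributePatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};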
