@@ -33,11 +33,15 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }

 // /// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by the
-// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// The warp op will still contain the original op
+// /// that will not be used by the yield op
+// /// (and should be cleaned up later with dce).
+// /// The yield op will bypass
 // /// the create_nd_tdesc's arguments.
-// /// The rewrite will create a subview of the size used by a single work item and
-// /// appropriate offset. The distributed create_nd_tdesc points into the subview
+// /// The rewrite will create a subview of the size
+// /// used by a single work item and appropriate offset.
+// /// The distributed create_nd_tdesc points into the
+// /// subview
 // /// without offset. The tensor descriptor type is distributed according to
 // /// sg_map attribute.
 // ///
@@ -75,8 +79,10 @@ struct WarpOpTensorDescOp final
 };

 // /// Sink a store_nd feeding into vector.yield op for the enclosing
-// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
-// /// through the warp op interface they would be propagated as returned values.
+// /// `vector.warp_execute_on_lane_0`. In case arguments
+// /// for the store are passed through the warp op interface
+// /// they would be propagated as returned
+// /// values.
 // /// Both the stored vector type and tensor descriptor types are distributed
 // /// according to sg_map attribute.
 // ///
@@ -97,20 +103,23 @@ struct WarpOpTensorDescOp final
 // /// ...
 // /// vector.yield
 // /// }
-// /// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+// /// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>,
+// ///     !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpStoreNd final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
+struct WarpOpStoreNd final
+    : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};

 // /// Clone a load_nd feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by the
-// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// The warp op will still contain the original op
+// /// that will not be used by the yield op
+// /// (and should be cleaned up later with dce).
+// /// The yield op will bypass
 // /// the load's arguments.
 // /// Both the loaded vector type and tensor descriptor types are distributed
 // /// according to sg_map attribute.
@@ -137,28 +146,27 @@ struct WarpOpTensorDescOp final
 // /// xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpLoadNd final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
-
-// FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
-//                                                xegpu::SGMapAttr sgMap) {
-//   llvm::SmallVector<int64_t, 2> distributedShape;
-//   auto layout = sgMap.getWiLayout();
-//   auto shape = originalT.getShape();
-//   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
-//     if (!divisible(APInt(64, o), APInt(64, l)))
-//       return failure();
-//     distributedShape.push_back(o / l);
-//   }
-//   auto newVectorType =
-//       VectorType::get(distributedShape, originalT.getElementType(),
-//                       originalT.getScalableDims());
-//   return newVectorType;
-// }
+struct WarpOpLoadNd final : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};

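+// Compute the per-work-item (distributed) vector type: every dimension of the
+// original shape is divided by the matching wi_layout entry of the sg_map.
+// E.g. vector<4x16xf32> with wi_layout [1, 16] distributes to vector<4x1xf32>.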
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}

 FailureOr<xegpu::TensorDescType>
 getDistributedTensorDescType(xegpu::TensorDescType originalT,
@@ -192,121 +200,117 @@ getDistributedTensorDescType(xegpu::TensorDescType originalT,
 }
 } // namespace

-// LogicalResult
-// WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                PatternRewriter &rewriter) const {
-//   auto yield = cast<vector::YieldOp>(
-//       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-//   Operation *lastNode = yield->getPrevNode();
-//   auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
-//   if (!storeOp)
-//     return failure();
-
-//   auto origType = storeOp.getTensorDescType();
-//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         storeOp, "the source tensor descriptor lacks sg_map attribute");
-
-//   if (storeOp.getTensorDescType().getShape().size() != 2)
-//     return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
-//   DBGS() << "Matched store_nd: " << storeOp << "\n";
-
-//   auto distributedTypeOrFailure =
-//       getDistributedVectorType(storeOp.getValueType(), sgMap);
-//   if (failed(distributedTypeOrFailure))
-//     return rewriter.notifyMatchFailure(storeOp,
-//                                        "Failed to distribute the type");
-//   VectorType newVectorType = distributedTypeOrFailure.value();
-
-//   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
-//       storeOp.getTensorDescType(), sgMap,
-//       storeOp.getTensorDescType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(storeOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp,
-//           ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
-//           TypeRange{newTDescType, newVectorType}, newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-//   auto newStoreOp =
-//       cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
-//   rewriter.eraseOp(storeOp);
-//   newStoreOp.getTensorDescMutable().assign(
-//       newWarpOp.getResult(newRetIndices[0]));
-//   newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
-
-//   return success();
-// }
-
-// LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                             PatternRewriter &rewriter) const {
-//   OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
-//     return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
-//   });
-
-//   if (!operand)
-//     return rewriter.notifyMatchFailure(warpOp,
-//                                        "warp result is not a xegpu::LoadNd op");
-
-//   auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
-
-//   if (loadOp.getPacked())
-//     return rewriter.notifyMatchFailure(
-//         loadOp, "Packed load distribution not supported");
-
-//   xegpu::TensorDescType origType = loadOp.getTensorDescType();
-//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         loadOp, "the source tensor descriptor lacks sg_map attribute");
-
-//   auto origShape = origType.getShape();
-//   if (origShape.size() != 2)
-//     return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
-
-//   auto distributedTypeOrFailure =
-//       getDistributedVectorType(loadOp.getType(), sgMap);
-//   if (failed(distributedTypeOrFailure))
-//     return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
-//   VectorType newVectorType = distributedTypeOrFailure.value();
-
-//   auto distributedDescTypeOrFailure =
-//       getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
-//                                    loadOp.getTensorDescType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(loadOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-
-//   unsigned operandIdx = operand->getOperandNumber();
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
-//           newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-
-//   auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-//       loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
-//       loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
-//       loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
-
-//   newLoadOp.getTensorDescMutable().assign(
-//       newWarpOp.getResult(newRetIndices[0]));
-//   Value distributedVal = newWarpOp.getResult(operandIdx);
-//   rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
-
-//   return success();
-// }
+LogicalResult WarpOpStoreNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                             PatternRewriter &rewriter) const {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
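+  // The stored value and the tensor descriptor must carry an sg_map that
+  // describes how each work item's slice of the data is laid out.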
+  auto origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  DBGS() << "Matched store_nd: " << storeOp << "\n";
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
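+  // Yield the tensor descriptor and the stored value from the warp op so that
+  // their distributed counterparts become results of the new warp op.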
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+      TypeRange{newTDescType, newVectorType}, newRetIndices);
+
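+  // Re-create the store after the warp op and rewire it to the distributed
+  // tensor descriptor and value returned by the new warp op.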
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
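+  // Look for a warp op result that is produced by a single-use xegpu.load_nd
+  // inside the region; only such loads can simply be moved out of the warp op.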
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
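+  // Return the tensor descriptor from the warp op; the load itself is
+  // re-created outside the region with the distributed vector type.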
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+      newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
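+  // Build the new load with the per-work-item vector type and redirect all
+  // uses of the original warp op result to it.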
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}

 LogicalResult
 WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -369,10 +373,9 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
       getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);

   SmallVector<size_t> newRetIndices;
-  gpu::WarpExecuteOnLane0Op newWarpOp =
-      moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
-          newRetIndices);
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+      newRetIndices);

   rewriter.setInsertionPointAfter(newWarpOp);
   auto subview = rewriter.create<memref::SubViewOp>(
@@ -393,6 +396,6 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,

 void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WarpOpTensorDescOp>(patterns.getContext());
-  // patterns.add<WarpOpStoreNd>(patterns.getContext());
-  // patterns.add<WarpOpLoadNd>(patterns.getContext());
+  patterns.add<WarpOpStoreNd>(patterns.getContext());
+  patterns.add<WarpOpLoadNd>(patterns.getContext());
 }
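
For context, these patterns are typically exercised by collecting them into a RewritePatternSet and running a greedy rewrite over the target op. The sketch below is illustrative only: the wrapper function, the include paths, and the driver call are assumptions, not part of this change.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver (illustrative): applies the XeGPU distribution patterns
// registered by populateXeGPUDistributePatterns until a fixpoint is reached.
static void runXeGPUDistribution(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Adds WarpOpTensorDescOp, WarpOpStoreNd and WarpOpLoadNd (see above).
  mlir::xegpu::populateXeGPUDistributePatterns(patterns);
  // Ignore convergence failure here; a real pass would signalPassFailure().
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}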