@@ -34,11 +34,10 @@ using namespace mlir;
3434namespace {
3535
3636// clang-format off
37- // / This pattern transform the CreateNdDescOp to create a subgroup descriptor
37+ // / This pattern transforms the CreateNdDescOp to create a subgroup descriptor
3838// / from a workgroup descriptor. It replaces the offsets and sizes with
3939// / appropriate values for the subgroup.
40- // / It uses round-robin distribution to create the subgroup descriptor.
41-
40+ // / It uses round-robin assignment to distribute the work to the subgroups.
4241// / Following create_nd_desc operation:
4342// / %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
4443// / -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
@@ -47,7 +46,7 @@ namespace {
4746// / %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
4847// / !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
4948// /
50- // / The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
49+ // / The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
5150// /
5251// / 24x24 matrix distribution example:
5352// / sg_layout = [4, 4], sg_data = [2, 2]
@@ -72,7 +71,6 @@ namespace {
7271// /
7372// / Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
7473// / distribution units (3x3) in total. Hence the 9 subgroup level operations.
75- // / Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
7674// clang-format on
7775struct WgToSgCreateNdOp : public OpConversionPattern <xegpu::CreateNdDescOp> {
7876 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
@@ -110,7 +108,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
110108 return rewriter.create <arith::ConstantIndexOp>(loc, value);
111109 }
112110
113- // Calculate global offset for each subgroup
111+ // Calculate offset for each subgroup
114112 SmallVector<OpFoldResult>
115113 calculateGlobalOffsets (ConversionPatternRewriter &rewriter, Location loc,
116114 const SmallVector<Value> &originalOffsets,
@@ -122,13 +120,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
122120 Value constOffsetY =
123121 createConstantIndex (rewriter, loc, distUnitBaseAddr[1 ]);
124122
125- // Compute offsets within entire tile
126123 Value offsetX =
127124 rewriter.createOrFold <index::AddOp>(loc, localOffset[0 ], constOffsetX);
128125 Value offsetY =
129126 rewriter.createOrFold <index::AddOp>(loc, localOffset[1 ], constOffsetY);
130127
131- // Add to global offsets
132128 size_t lastDimIndex = originalOffsets.size () - 1 ;
133129 size_t secondLastDimIndex = lastDimIndex - 1 ;
134130
@@ -137,7 +133,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
137133 Value globalOffsetY = rewriter.createOrFold <index::AddOp>(
138134 loc, originalOffsets[lastDimIndex], offsetY);
139135
140- // Create final offset list
141136 SmallVector<OpFoldResult> globalOffsets (originalOffsets.begin (),
142137 originalOffsets.end ());
143138 globalOffsets[secondLastDimIndex] = globalOffsetX;
@@ -172,7 +167,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
172167 sgDataDim[i] = createConstantIndex (rewriter, loc, sgShape[i]);
173168 }
174169
175- // Delinearize the 1D subgroup id into nd coordinates
170+ // Delinearize the 1D subgroup id into 2d
176171 SmallVector<Value> sgIds = delinearizeSubgroupId (
177172 rewriter, loc, linearSgId, sgLayoutDim[0 ], sgLayoutDim[1 ]);
178173
@@ -207,8 +202,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
207202 }
208203};
209204
210- // / This pattern transforms the LoadNdOp to load from a subgroup descriptor
211- // / It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
205+ // / This pattern transforms the LoadNdOp to load subgroup data.
212206struct WgToSgLoadNdOp : public OpConversionPattern <xegpu::LoadNdOp> {
213207 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
214208 LogicalResult
@@ -310,7 +304,22 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
310304 }
311305 }
312306 rewriter.replaceOpWithMultiple (op, {newDpasOps});
313- return mlir::success ();
307+ return success ();
308+ }
309+ };
310+
311+ // / This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
312+ struct WgToSgPrefetchNdOp : public OpConversionPattern <xegpu::PrefetchNdOp> {
313+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
314+ LogicalResult
315+ matchAndRewrite (xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
316+ ConversionPatternRewriter &rewriter) const override {
317+ for (auto src : adaptor.getTensorDesc ()) {
318+ rewriter.create <xegpu::PrefetchNdOp>(op.getLoc (), TypeRange (), src,
319+ op->getAttrs ());
320+ }
321+ rewriter.eraseOp (op);
322+ return success ();
314323 }
315324};
316325
@@ -320,7 +329,8 @@ namespace mlir {
320329namespace xegpu {
321330void populateXeGPUWgToSgPatterns (RewritePatternSet &patterns) {
322331 patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
323- WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(patterns.getContext ());
332+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
333+ patterns.getContext ());
324334}
325335} // namespace xegpu
326336} // namespace mlir
@@ -345,6 +355,8 @@ void XeGPUWgToSgPass::runOnOperation() {
345355 return storeOp.getTensorDescType ();
346356 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
347357 return updateOp.getType ();
358+ if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
359+ return prefetchOp.getTensorDescType ();
348360 return xegpu::TensorDescType ();
349361 };
350362
@@ -353,12 +365,12 @@ void XeGPUWgToSgPass::runOnOperation() {
353365 };
354366
355367 target.addDynamicallyLegalOp <xegpu::CreateNdDescOp, xegpu::LoadNdOp,
356- xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
357- [=](Operation *op) -> bool {
358- auto tdescTy = getTensorDescType (op);
359- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout ());
360- return isLegal (layout);
361- });
368+ xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
369+ xegpu::PrefetchNdOp>( [=](Operation *op) -> bool {
370+ auto tdescTy = getTensorDescType (op);
371+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout ());
372+ return isLegal (layout);
373+ });
362374
363375 target.addDynamicallyLegalOp <xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
364376 auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr (" layout" ));
0 commit comments