@@ -182,16 +182,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                                   layout.dropSgLayoutAndData());

     SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();

     for (auto tdescOffsets : *maybeTdescOffsets) {
       SmallVector<OpFoldResult> sgOffsets;
       size_t rank = tdescOffsets.size();
       for (size_t i = 0; i < rank; i++) {
-        size_t idx = wgOffsets.size() - rank + i;
+        size_t idx = origOffsets.size() - rank + i;
         Value add = rewriter.createOrFold<index::AddOp>(
             loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
         sgOffsets.push_back(add);
       }

@@ -296,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };

+// Utility function to compute global offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and subgroup relative offsets.
+static SmallVector<SmallVector<OpFoldResult>>
+computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+               ArrayRef<OpFoldResult> origOffsets,
+               ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> finalOffsets;
+  Location loc = op->getLoc();
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = sgOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = origOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    finalOffsets.push_back(std::move(newOffsets));
+  }
+  return finalOffsets;
+}
+
+// Utility function to compute sgShape and sgOffsetList for a given op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
+                           ConversionPatternRewriter &rewriter,
+                           SmallVector<int64_t> &sgShape,
+                           SmallVector<SmallVector<Value>> &sgOffsetList) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+
+  SmallVector<int64_t> sgLayout;
+  auto sgLayoutAttr = layout.getSgLayout();
+  if (!sgLayoutAttr)
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  int count;
+  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+  // Get the subgroup ID
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+  if (sgIdRangeSpecified) {
+    int64_t sgCount = endOfRange - startOfRange;
+    if (computeProduct(sgLayout) != sgCount)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
+    return failure();
+
+  sgOffsetList = *sgOffsets;
+  return success();
+}
+
+// Utility function to collect the original (workgroup-level) offsets of an op,
+// merging its constant and dynamic offsets into a single OpFoldResult list.
+template <typename OpTy>
+SmallVector<OpFoldResult> getOffsets(OpTy op,
+                                     ConversionPatternRewriter &rewriter) {
+  SmallVector<OpFoldResult> origOffsets;
+  if (auto constOffsets = op.getConstOffsetsAttr()) {
+    for (auto attr : constOffsets.asArrayRef())
+      origOffsets.push_back(rewriter.getIndexAttr(attr));
+  }
+  for (auto v : op.getOffsets())
+    origOffsets.push_back(v);
+  return origOffsets;
+}
+
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    SmallVector<Value> newLoadOps;
+    for (auto [offsets, tdesc] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+      VectorType newResTy = VectorType::get(
+          sgShape,
+          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, offsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc, value] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+    : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
 /// offsets of the new subgroup src tensor descriptors.
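
The offset arithmetic above (both in the updated WgToSgCreateNdOp hunk and in the new computeOffsets helper) adds each subgroup-relative offset to the corresponding workgroup offset, aligned to the innermost dimensions: when the op carries more offsets than the subgroup offsets have rank, only the trailing offsets are adjusted. The following standalone C++ sketch models that indexing with hypothetical plain-integer values in place of OpFoldResult/Value; it is an illustration only, not code from this patch.

// Sketch only: models the index arithmetic of computeOffsets with plain
// integers. The sample values are made up for illustration.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Workgroup-level offsets carried by the op, e.g. [batch, y, x].
  std::vector<int64_t> origOffsets = {2, 64, 128};
  // Relative offsets computed for one subgroup; rank 2 covers only [y, x].
  std::vector<int64_t> sgOffsets = {32, 0};

  std::vector<int64_t> finalOffsets;
  size_t rank = sgOffsets.size();
  for (size_t i = 0; i < rank; ++i) {
    // Align to the innermost dimensions of the original offset list.
    size_t idx = origOffsets.size() - rank + i;
    finalOffsets.push_back(sgOffsets[i] + origOffsets[idx]);
  }

  // Prints "96 128": y = 64 + 32, x = 128 + 0. The leading batch offset is
  // not part of the per-subgroup result produced by this helper.
  for (int64_t v : finalOffsets)
    std::printf("%lld ", static_cast<long long>(v));
  std::printf("\n");
}
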
@@ -690,12 +889,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-               WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-               WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
-               WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
-               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+           WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
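
As a usage note, the sketch below shows one plausible way the registered patterns could be driven through MLIR's dialect-conversion framework. It is a hypothetical illustration, not the pass in this file: the distributeWgToSg wrapper, the include paths, and the legality rule (an nd op still needs rewriting while a tensor_desc operand or result keeps an sg_layout) are assumptions made for the sketch; the real workgroup-to-subgroup pass defines its own, more detailed legality logic.

// Hypothetical driver, not part of this patch. Header paths and the legality
// rule are assumptions; only populateXeGPUWgToSgDistributePatterns comes from
// the code above.
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

static LogicalResult distributeWgToSg(Operation *root) {
  MLIRContext *ctx = root->getContext();

  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);

  // A type still carries workgroup-level layout info if its tensor_desc
  // layout has an sg_layout attached.
  auto hasWgLayout = [](Type ty) {
    auto tdescTy = dyn_cast<xegpu::TensorDescType>(ty);
    if (!tdescTy)
      return false;
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
    return layout && layout.getSgLayout();
  };

  ConversionTarget target(*ctx);
  // Leave unrelated ops alone.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  // Nd ops become legal once all descriptor types have dropped sg_layout.
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(
      [hasWgLayout](Operation *op) {
        return llvm::none_of(op->getOperandTypes(), hasWgLayout) &&
               llvm::none_of(op->getResultTypes(), hasWgLayout);
      });

  return applyPartialConversion(root, target, std::move(patterns));
}
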