@@ -77,6 +77,82 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
7777 return std::make_pair (sgShape, count);
7878}
7979
80+ // Helper function to compute new offsets for subgroup operations.
81+ static SmallVector<SmallVector<OpFoldResult>>
82+ computeSgOffsets (PatternRewriter &rewriter, Location loc,
83+ xegpu::LayoutAttr layout, Value linearSgId,
84+ ArrayRef<int64_t > wgShape, ArrayRef<OpFoldResult> oldOffsets) {
85+ SmallVector<SmallVector<OpFoldResult>> result;
86+ auto maybeTdescOffsets =
87+ layout.getOffsets (rewriter, loc, linearSgId, wgShape);
88+ if (failed (maybeTdescOffsets))
89+ return result;
90+
91+ for (auto &tdescOffsets : *maybeTdescOffsets) {
92+ SmallVector<OpFoldResult> newOffsets;
93+ size_t rank = tdescOffsets.size ();
94+ for (size_t i = 0 ; i < rank; i++) {
95+ size_t idx = oldOffsets.size () - rank + i;
96+ Value add = rewriter.createOrFold <index::AddOp>(
97+ loc, tdescOffsets[i],
98+ getValueOrCreateConstantIndexOp (rewriter, loc, oldOffsets[idx]));
99+ newOffsets.push_back (add);
100+ }
101+ result.push_back (std::move (newOffsets));
102+ }
103+ return result;
104+ }
105+
// Helper struct to hold extracted subgroup info for ops with explicit offsets.
// Populated by extractSgOffsetInfo; field order matters — it is initialized
// positionally via aggregate initialization.
struct SgOffsetInfo {
  Location loc;                        // Location of the original op.
  Value tdesc;                         // Workgroup-level tensor descriptor.
  xegpu::TensorDescType tdescTy;       // Type of the tensor descriptor.
  xegpu::LayoutAttr layout;            // Subgroup layout attached to tdescTy.
  SmallVector<int64_t> sgShape;        // Per-subgroup data shape.
  int count;                           // Number of chunks per subgroup.
  Value linearSgId;                    // Linearized subgroup id (gpu.subgroup_id).
  SmallVector<OpFoldResult> oldOffsets; // Original (workgroup-level) offsets.
};
117+
118+ // Helper function to extract subgroup info for ops with explicit offsets.
119+ // Returns std::nullopt on failure.
120+ template <typename OpTy>
121+ std::optional<SgOffsetInfo>
122+ extractSgOffsetInfo (OpTy op, ConversionPatternRewriter &rewriter) {
123+ int64_t offsetSize = static_cast <int64_t >(op.getOffsets ().size ());
124+ if (offsetSize == 0 )
125+ return std::nullopt ;
126+
127+ Location loc = op.getLoc ();
128+ Value tdesc = op.getTensorDesc ();
129+ auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType ());
130+ if (!tdescTy)
131+ return std::nullopt ;
132+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout ());
133+ if (!layout)
134+ return std::nullopt ;
135+
136+ ArrayRef<int64_t > wgShape = tdescTy.getShape ();
137+ SmallVector<int64_t > sgShape;
138+ int count;
139+ std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
140+
141+ Value linearSgId =
142+ gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
143+
144+ SmallVector<OpFoldResult> oldOffsets;
145+ if (auto constOffsets = op.getConstOffsetsAttr ()) {
146+ for (auto attr : constOffsets.asArrayRef ())
147+ oldOffsets.push_back (rewriter.getIndexAttr (attr));
148+ }
149+ for (auto v : op.getOffsets ())
150+ oldOffsets.push_back (v);
151+
152+ return SgOffsetInfo{loc, tdesc, tdescTy, layout,
153+ sgShape, count, linearSgId, oldOffsets};
154+ }
155+
80156// / This pattern transforms the CreateNdDescOp to create a subgroup descriptor
81157// / from a workgroup descriptor. It replaces the offsets and sizes with
82158// / appropriate values for the subgroup.
@@ -275,6 +351,43 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
275351 }
276352};
277353
354+ // This pattern transforms the LoadNdOp with explicit offsets to load subgroup
355+ // data.
356+ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern <xegpu::LoadNdOp> {
357+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
358+
359+ LogicalResult
360+ matchAndRewrite (xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
361+ ConversionPatternRewriter &rewriter) const override {
362+
363+ auto infoOpt = extractSgOffsetInfo (op, rewriter);
364+ if (!infoOpt)
365+ return failure ();
366+ const auto &info = *infoOpt;
367+
368+ auto sgOffsets =
369+ computeSgOffsets (rewriter, info.loc , info.layout , info.linearSgId ,
370+ info.tdescTy .getShape (), info.oldOffsets );
371+ if (sgOffsets.empty ())
372+ return failure ();
373+
374+ SmallVector<Value> newLoadOps;
375+ auto tdescRange = adaptor.getTensorDesc ();
376+ for (auto it : llvm::zip (sgOffsets, tdescRange)) {
377+ VectorType newResTy =
378+ VectorType::get (info.sgShape , info.tdescTy .getElementType ());
379+ auto newLoadOp = rewriter.create <xegpu::LoadNdOp>(
380+ info.loc , newResTy, std::get<1 >(it), std::get<0 >(it),
381+ /* packed=*/ nullptr ,
382+ /* transpose=*/ nullptr , op.getL1HintAttr (), op.getL2HintAttr (),
383+ op.getL3HintAttr ());
384+ newLoadOps.push_back (newLoadOp);
385+ }
386+ rewriter.replaceOpWithMultiple (op, {newLoadOps});
387+ return success ();
388+ }
389+ };
390+
278391// / This pattern transforms the StoreNdOp to store to a subgroup descriptor
279392// / It creates a StoreNdOp op to store the updated values to the new subgroup
280393// / src tensor descriptors.
@@ -297,6 +410,39 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
297410 }
298411};
299412
413+ // This pattern transforms the StoreNdOp with explicit offsets to store
414+ // subgroup data.
415+ struct WgToSgStoreNdOpWithOffset
416+ : public OpConversionPattern<xegpu::StoreNdOp> {
417+ using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
418+
419+ LogicalResult
420+ matchAndRewrite (xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
421+ ConversionPatternRewriter &rewriter) const override {
422+
423+ auto infoOpt = extractSgOffsetInfo (op, rewriter);
424+ if (!infoOpt)
425+ return failure ();
426+ const auto &info = *infoOpt;
427+
428+ auto sgOffsets =
429+ computeSgOffsets (rewriter, info.loc , info.layout , info.linearSgId ,
430+ info.tdescTy .getShape (), info.oldOffsets );
431+ if (sgOffsets.empty ())
432+ return failure ();
433+
434+ auto tdescRange = adaptor.getTensorDesc ();
435+ auto valueRange = adaptor.getValue ();
436+ for (auto it : llvm::zip (sgOffsets, tdescRange, valueRange)) {
437+ rewriter.create <xegpu::StoreNdOp>(
438+ info.loc , std::get<2 >(it), std::get<1 >(it), std::get<0 >(it),
439+ op.getL1HintAttr (), op.getL2HintAttr (), op.getL3HintAttr ());
440+ }
441+ rewriter.eraseOp (op);
442+ return success ();
443+ }
444+ };
445+
300446// / This pattern transforms the UpdateNdOffsetOp to update the offsets of a
301447// / subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
302448// / offsets of the new subgroup src tensor descriptors.
@@ -383,6 +529,38 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
383529 }
384530};
385531
532+ // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
533+ // subgroup data.
534+ struct WgToSgPrefetchNdOpWithOffset
535+ : public OpConversionPattern<xegpu::PrefetchNdOp> {
536+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
537+
538+ LogicalResult
539+ matchAndRewrite (xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
540+ ConversionPatternRewriter &rewriter) const override {
541+
542+ auto infoOpt = extractSgOffsetInfo (op, rewriter);
543+ if (!infoOpt)
544+ return failure ();
545+ const auto &info = *infoOpt;
546+
547+ auto sgOffsets =
548+ computeSgOffsets (rewriter, info.loc , info.layout , info.linearSgId ,
549+ info.tdescTy .getShape (), info.oldOffsets );
550+ if (sgOffsets.empty ())
551+ return failure ();
552+
553+ auto tdescRange = adaptor.getTensorDesc ();
554+ for (auto it : llvm::zip (sgOffsets, tdescRange)) {
555+ rewriter.create <xegpu::PrefetchNdOp>(
556+ info.loc , std::get<1 >(it), std::get<0 >(it), op.getL1HintAttr (),
557+ op.getL2HintAttr (), op.getL3HintAttr ());
558+ }
559+ rewriter.eraseOp (op);
560+ return success ();
561+ }
562+ };
563+
386564// / This pattern transforms vector.broadcast ops to work at subgroup level.
387565struct WgToSgVectorBroadcastOp
388566 : public OpConversionPattern<vector::BroadcastOp> {
0 commit comments