@@ -77,83 +77,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
7777 return std::make_pair (sgShape, count);
7878}
7979
80- // Helper function to compute new offsets for subgroup operations.
81- static SmallVector<SmallVector<OpFoldResult>>
82- computeSgOffsets (PatternRewriter &rewriter, Location loc,
83- xegpu::LayoutAttr layout, Value linearSgId,
84- ArrayRef<int64_t > wgShape, ArrayRef<OpFoldResult> oldOffsets) {
85- SmallVector<SmallVector<OpFoldResult>> result;
86- auto maybeTdescOffsets =
87- layout.getOffsets (rewriter, loc, linearSgId, wgShape);
88- if (failed (maybeTdescOffsets))
89- return result;
90-
91- for (auto &tdescOffsets : *maybeTdescOffsets) {
92- SmallVector<OpFoldResult> newOffsets;
93- size_t rank = tdescOffsets.size ();
94- for (size_t i = 0 ; i < rank; i++) {
95- size_t idx = oldOffsets.size () - rank + i;
96- Value add = rewriter.createOrFold <index::AddOp>(
97- loc, tdescOffsets[i],
98- getValueOrCreateConstantIndexOp (rewriter, loc, oldOffsets[idx]));
99- newOffsets.push_back (add);
100- }
101- result.push_back (std::move (newOffsets));
102- }
103- return result;
104- }
105-
106- // Helper struct to hold extracted subgroup info for ops with explicit offsets.
107- struct SgOffsetInfo {
108- Location loc;
109- Value tdesc;
110- xegpu::TensorDescType tdescTy;
111- xegpu::LayoutAttr layout;
112- SmallVector<int64_t > sgShape;
113- int count;
114- Value linearSgId;
115- SmallVector<OpFoldResult> oldOffsets;
116- };
117-
118- // Helper function to extract subgroup info for ops with explicit offsets.
119- // Returns std::nullopt on failure.
120- template <typename OpTy>
121- std::optional<SgOffsetInfo>
122- extractSgOffsetInfo (OpTy op, ConversionPatternRewriter &rewriter) {
123-
124- int64_t offsetSize = static_cast <int64_t >(op.getOffsets ().size ());
125- if (offsetSize == 0 && (!op.getConstOffsetsAttr ()))
126- return std::nullopt ;
127-
128- Location loc = op.getLoc ();
129- Value tdesc = op.getTensorDesc ();
130- auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType ());
131- if (!tdescTy)
132- return std::nullopt ;
133- auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout ());
134- if (!layout)
135- return std::nullopt ;
136-
137- ArrayRef<int64_t > wgShape = tdescTy.getShape ();
138- SmallVector<int64_t > sgShape;
139- int count;
140- std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
141-
142- Value linearSgId =
143- gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
144-
145- SmallVector<OpFoldResult> oldOffsets;
146- if (auto constOffsets = op.getConstOffsetsAttr ()) {
147- for (auto attr : constOffsets.asArrayRef ())
148- oldOffsets.push_back (rewriter.getIndexAttr (attr));
149- }
150- for (auto v : op.getOffsets ())
151- oldOffsets.push_back (v);
152-
153- return SgOffsetInfo{loc, tdesc, tdescTy, layout,
154- sgShape, count, linearSgId, oldOffsets};
155- }
156-
15780// / This pattern transforms the CreateNdDescOp to create a subgroup descriptor
15881// / from a workgroup descriptor. It replaces the offsets and sizes with
15982// / appropriate values for the subgroup.
@@ -351,43 +274,6 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
351274 }
352275};
353276
354- // This pattern transforms the LoadNdOp with explicit offsets to load subgroup
355- // data.
356- struct WgToSgLoadNdOpWithOffset : public OpConversionPattern <xegpu::LoadNdOp> {
357- using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
358-
359- LogicalResult
360- matchAndRewrite (xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
361- ConversionPatternRewriter &rewriter) const override {
362-
363- auto infoOpt = extractSgOffsetInfo (op, rewriter);
364- if (!infoOpt)
365- return failure ();
366- const auto &info = *infoOpt;
367-
368- auto sgOffsets =
369- computeSgOffsets (rewriter, info.loc , info.layout , info.linearSgId ,
370- info.tdescTy .getShape (), info.oldOffsets );
371- if (sgOffsets.empty ())
372- return failure ();
373-
374- SmallVector<Value> newLoadOps;
375- auto tdescRange = adaptor.getTensorDesc ();
376- for (auto it : llvm::zip (sgOffsets, tdescRange)) {
377- VectorType newResTy =
378- VectorType::get (info.sgShape , info.tdescTy .getElementType ());
379- auto newLoadOp = rewriter.create <xegpu::LoadNdOp>(
380- info.loc , newResTy, std::get<1 >(it), std::get<0 >(it),
381- /* packed=*/ nullptr ,
382- /* transpose=*/ nullptr , op.getL1HintAttr (), op.getL2HintAttr (),
383- op.getL3HintAttr ());
384- newLoadOps.push_back (newLoadOp);
385- }
386- rewriter.replaceOpWithMultiple (op, {newLoadOps});
387- return success ();
388- }
389- };
390-
391277// / This pattern transforms the StoreNdOp to store to a subgroup descriptor
392278// / It creates a StoreNdOp op to store the updated values to the new subgroup
393279// / src tensor descriptors.
@@ -410,36 +296,192 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
410296 }
411297};
412298
299+ // This pattern transforms the LoadNdOp with explicit offsets to load subgroup
300+ // data.
301+ // Use a template parameter for the adaptor type
302+ template <typename OpTy, typename AdaptorTy, typename CreateFn>
303+ LogicalResult distributeNdOpWithOffset (OpTy op, AdaptorTy adaptor,
304+ ConversionPatternRewriter &rewriter,
305+ CreateFn &&createOp) {
306+ int64_t offsetSize = static_cast <int64_t >(op.getOffsets ().size ());
307+ if (offsetSize == 0 && (!op.getConstOffsetsAttr ()))
308+ return failure ();
309+
310+ Location loc = op.getLoc ();
311+ Value tdesc = op.getTensorDesc ();
312+ auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType ());
313+ if (!tdescTy)
314+ return failure ();
315+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout ());
316+ if (!layout)
317+ return failure ();
318+
319+ SmallVector<int64_t > sgLayout;
320+ if (auto sgLayoutAttr = layout.getSgLayout ())
321+ sgLayout = llvm::to_vector_of<int64_t >(sgLayoutAttr.asArrayRef ());
322+ else
323+ return rewriter.notifyMatchFailure (
324+ op, " sgLayout attribute is required in layout" );
325+
326+ ArrayRef<int64_t > wgShape = tdescTy.getShape ();
327+ SmallVector<int64_t > sgShape;
328+ int count;
329+ std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
330+
331+ // Get the subgroup ID
332+ Value linearSgId =
333+ gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
334+
335+ int64_t startOfRange = -1 , endOfRange = -1 ;
336+ bool sgIdRangeSpecified = isSgIdRangeSpecified (op, startOfRange, endOfRange);
337+
338+ if (sgIdRangeSpecified) {
339+ int64_t sgCount = endOfRange - startOfRange;
340+ if (computeProduct (sgLayout) != sgCount)
341+ return rewriter.notifyMatchFailure (
342+ op, " sg_layout size must match the sg_id_range" );
343+ Value startOfRangeVal =
344+ rewriter.create <arith::ConstantIndexOp>(loc, startOfRange);
345+ linearSgId =
346+ rewriter.createOrFold <index::SubOp>(loc, linearSgId, startOfRangeVal);
347+ }
348+
349+ auto maybeTdescOffsets =
350+ layout.getOffsets (rewriter, loc, linearSgId, wgShape);
351+ if (failed (maybeTdescOffsets))
352+ return failure ();
353+
354+ SmallVector<OpFoldResult> oldOffsets;
355+ if (auto constOffsets = op.getConstOffsetsAttr ()) {
356+ for (auto attr : constOffsets.asArrayRef ())
357+ oldOffsets.push_back (rewriter.getIndexAttr (attr));
358+ }
359+ for (auto v : op.getOffsets ())
360+ oldOffsets.push_back (v);
361+
362+ // Delegate to the operation-specific creation function
363+ return createOp (loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
364+ rewriter, op);
365+ }
366+
367+ // Usage for LoadNdOp
368+ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern <xegpu::LoadNdOp> {
369+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
370+ LogicalResult matchAndRewrite (
371+ xegpu::LoadNdOp op,
372+ typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
373+ ConversionPatternRewriter &rewriter) const override {
374+ return distributeNdOpWithOffset (
375+ op, adaptor, rewriter,
376+ [](Location loc, SmallVector<int64_t > &sgShape,
377+ ArrayRef<SmallVector<Value>> tdescOffsetsList,
378+ SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
379+ ConversionPatternRewriter &rewriter,
380+ xegpu::LoadNdOp &op) -> LogicalResult {
381+ SmallVector<Value> newLoadOps;
382+ for (auto [tdescOffsets, tdesc] :
383+ llvm::zip (tdescOffsetsList, adaptor.getTensorDesc ())) {
384+ SmallVector<OpFoldResult> newOffsets;
385+ size_t rank = tdescOffsets.size ();
386+ for (size_t i = 0 ; i < rank; i++) {
387+ size_t idx = oldOffsets.size () - rank + i;
388+ Value add = rewriter.createOrFold <index::AddOp>(
389+ loc, tdescOffsets[i],
390+ getValueOrCreateConstantIndexOp (rewriter, loc,
391+ oldOffsets[idx]));
392+ newOffsets.push_back (add);
393+ }
394+ VectorType newResTy = VectorType::get (
395+ sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType ())
396+ .getElementType ());
397+ auto newLoadOp = rewriter.create <xegpu::LoadNdOp>(
398+ loc, newResTy, tdesc, newOffsets,
399+ /* packed=*/ nullptr ,
400+ /* transpose=*/ nullptr , op.getL1HintAttr (), op.getL2HintAttr (),
401+ op.getL3HintAttr ());
402+ newLoadOps.push_back (newLoadOp);
403+ }
404+ rewriter.replaceOpWithMultiple (op, {newLoadOps});
405+ return success ();
406+ });
407+ }
408+ };
409+
413410// This pattern transforms the StoreNdOp with explicit offsets to store
414411// subgroup data.
415412struct WgToSgStoreNdOpWithOffset
416413 : public OpConversionPattern<xegpu::StoreNdOp> {
417414 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
415+ LogicalResult matchAndRewrite (
416+ xegpu::StoreNdOp op,
417+ typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
418+ ConversionPatternRewriter &rewriter) const override {
419+ return distributeNdOpWithOffset (
420+ op, adaptor, rewriter,
421+ [](Location loc, SmallVector<int64_t > &sgShape,
422+ ArrayRef<SmallVector<Value>> tdescOffsetsList,
423+ SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
424+ ConversionPatternRewriter &rewriter,
425+ xegpu::StoreNdOp &op) -> LogicalResult {
426+ for (auto [tdescOffsets, tdesc, value] :
427+ llvm::zip (tdescOffsetsList, adaptor.getTensorDesc (),
428+ adaptor.getValue ())) {
429+ SmallVector<OpFoldResult> newOffsets;
430+ size_t rank = tdescOffsets.size ();
431+ for (size_t i = 0 ; i < rank; i++) {
432+ size_t idx = oldOffsets.size () - rank + i;
433+ Value add = rewriter.createOrFold <index::AddOp>(
434+ loc, tdescOffsets[i],
435+ getValueOrCreateConstantIndexOp (rewriter, loc,
436+ oldOffsets[idx]));
437+ newOffsets.push_back (add);
438+ }
439+ rewriter.create <xegpu::StoreNdOp>(
440+ loc, value, tdesc, newOffsets, op.getL1HintAttr (),
441+ op.getL2HintAttr (), op.getL3HintAttr ());
442+ }
443+ rewriter.eraseOp (op);
444+ return success ();
445+ });
446+ }
447+ };
418448
419- LogicalResult
420- matchAndRewrite (xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
421- ConversionPatternRewriter &rewriter) const override {
422-
423- auto infoOpt = extractSgOffsetInfo (op, rewriter);
424- if (!infoOpt)
425- return failure ();
426- const auto &info = *infoOpt;
427-
428- auto sgOffsets =
429- computeSgOffsets (rewriter, info.loc , info.layout , info.linearSgId ,
430- info.tdescTy .getShape (), info.oldOffsets );
431- if (sgOffsets.empty ())
432- return failure ();
433-
434- auto tdescRange = adaptor.getTensorDesc ();
435- auto valueRange = adaptor.getValue ();
436- for (auto it : llvm::zip (sgOffsets, tdescRange, valueRange)) {
437- rewriter.create <xegpu::StoreNdOp>(
438- info.loc , std::get<2 >(it), std::get<1 >(it), std::get<0 >(it),
439- op.getL1HintAttr (), op.getL2HintAttr (), op.getL3HintAttr ());
440- }
441- rewriter.eraseOp (op);
442- return success ();
449+ // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
450+ // subgroup data.
451+ struct WgToSgPrefetchNdOpWithOffset
452+ : public OpConversionPattern<xegpu::PrefetchNdOp> {
453+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
454+ LogicalResult matchAndRewrite (
455+ xegpu::PrefetchNdOp op,
456+ typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
457+ adaptor,
458+ ConversionPatternRewriter &rewriter) const override {
459+ return distributeNdOpWithOffset (
460+ op, adaptor, rewriter,
461+ [](Location loc, SmallVector<int64_t > &sgShape,
462+ ArrayRef<SmallVector<Value>> tdescOffsetsList,
463+ SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
464+ ConversionPatternRewriter &rewriter,
465+ xegpu::PrefetchNdOp &op) -> LogicalResult {
466+ for (auto [tdescOffsets, tdesc] :
467+ llvm::zip (tdescOffsetsList, adaptor.getTensorDesc ())) {
468+ SmallVector<OpFoldResult> newOffsets;
469+ size_t rank = tdescOffsets.size ();
470+ for (size_t i = 0 ; i < rank; i++) {
471+ size_t idx = oldOffsets.size () - rank + i;
472+ Value add = rewriter.createOrFold <index::AddOp>(
473+ loc, tdescOffsets[i],
474+ getValueOrCreateConstantIndexOp (rewriter, loc,
475+ oldOffsets[idx]));
476+ newOffsets.push_back (add);
477+ }
478+ rewriter.create <xegpu::PrefetchNdOp>(
479+ loc, tdesc, newOffsets, op.getL1HintAttr (), op.getL2HintAttr (),
480+ op.getL3HintAttr ());
481+ }
482+ rewriter.eraseOp (op);
483+ return success ();
484+ });
443485 }
444486};
445487
@@ -529,38 +571,6 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
529571 }
530572};
531573
532- // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
533- // subgroup data.
534- struct WgToSgPrefetchNdOpWithOffset
535- : public OpConversionPattern<xegpu::PrefetchNdOp> {
536- using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
537-
538- LogicalResult
539- matchAndRewrite (xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
540- ConversionPatternRewriter &rewriter) const override {
541-
542- auto infoOpt = extractSgOffsetInfo (op, rewriter);
543- if (!infoOpt)
544- return failure ();
545- const auto &info = *infoOpt;
546-
547- auto sgOffsets =
548- computeSgOffsets (rewriter, info.loc , info.layout , info.linearSgId ,
549- info.tdescTy .getShape (), info.oldOffsets );
550- if (sgOffsets.empty ())
551- return failure ();
552-
553- auto tdescRange = adaptor.getTensorDesc ();
554- for (auto it : llvm::zip (sgOffsets, tdescRange)) {
555- rewriter.create <xegpu::PrefetchNdOp>(
556- info.loc , std::get<1 >(it), std::get<0 >(it), op.getL1HintAttr (),
557- op.getL2HintAttr (), op.getL3HintAttr ());
558- }
559- rewriter.eraseOp (op);
560- return success ();
561- }
562- };
563-
564574// / This pattern transforms vector.broadcast ops to work at subgroup level.
565575struct WgToSgVectorBroadcastOp
566576 : public OpConversionPattern<vector::BroadcastOp> {
0 commit comments