@@ -296,10 +296,38 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
 }
 };
 
-template <typename OpTy, typename AdaptorTy, typename CreateFn>
-LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
-                                       ConversionPatternRewriter &rewriter,
-                                       CreateFn &&createOp) {
+// Utility function to compute distributed offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and subgroup relative offsets.
+static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
+    Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+    ArrayRef<OpFoldResult> wgOffsets, ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> distributedOffsets;
+  Location loc = op->getLoc();
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = sgOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = wgOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    distributedOffsets.push_back(std::move(newOffsets));
+  }
+  return distributedOffsets;
+}
+
+// Utility function to get sgShape, sgOffsetList, and wgOffsets for a given
+// op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult
+prepareOpDistribution(OpTy op, AdaptorTy adaptor,
+                      ConversionPatternRewriter &rewriter,
+                      SmallVector<int64_t> &sgShape,
+                      SmallVector<SmallVector<Value>> &sgOffsetList,
+                      SmallVector<OpFoldResult> &wgOffsets) {
   int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
   if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
     return failure();
@@ -321,7 +349,6 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
         op, "sgLayout attribute is required in layout");
 
   ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  SmallVector<int64_t> sgShape;
   int count;
   std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
@@ -343,21 +370,19 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
         rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
   }
 
-  auto maybeTdescOffsets =
-      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(maybeTdescOffsets))
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
     return failure();
 
-  SmallVector<OpFoldResult> oldOffsets;
   if (auto constOffsets = op.getConstOffsetsAttr()) {
     for (auto attr : constOffsets.asArrayRef())
-      oldOffsets.push_back(rewriter.getIndexAttr(attr));
+      wgOffsets.push_back(rewriter.getIndexAttr(attr));
   }
   for (auto v : op.getOffsets())
-    oldOffsets.push_back(v);
+    wgOffsets.push_back(v);
 
-  return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
-                  rewriter, op);
+  sgOffsetList = *sgOffsets;
+  return success();
 }
 
 // This pattern transforms the LoadNdOp with explicit offsets to load
@@ -368,39 +393,31 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       xegpu::LoadNdOp op,
       typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::LoadNdOp &op) -> LogicalResult {
-          SmallVector<Value> newLoadOps;
-          for (auto [tdescOffsets, tdesc] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            VectorType newResTy = VectorType::get(
-                sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType())
-                             .getElementType());
-            auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-                loc, newResTy, tdesc, newOffsets,
-                /*packed=*/nullptr,
-                /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-                op.getL3HintAttr());
-            newLoadOps.push_back(newLoadOp);
-          }
-          rewriter.replaceOpWithMultiple(op, {newLoadOps});
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    SmallVector<Value> newLoadOps;
+    for (auto [newOffsets, tdesc] :
+         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+      VectorType newResTy = VectorType::get(
+          sgShape,
+          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, newOffsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
   }
 };
 
@@ -413,33 +430,24 @@ struct WgToSgStoreNdOpWithOffset
       xegpu::StoreNdOp op,
       typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::StoreNdOp &op) -> LogicalResult {
-          for (auto [tdescOffsets, tdesc, value] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
-                         adaptor.getValue())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            rewriter.create<xegpu::StoreNdOp>(
-                loc, value, tdesc, newOffsets, op.getL1HintAttr(),
-                op.getL2HintAttr(), op.getL3HintAttr());
-          }
-          rewriter.eraseOp(op);
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    for (auto [newOffsets, tdesc, value] : llvm::zip(
+             distributedOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, newOffsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
@@ -453,32 +461,24 @@ struct WgToSgPrefetchNdOpWithOffset
       typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
           adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::PrefetchNdOp &op) -> LogicalResult {
-          for (auto [tdescOffsets, tdesc] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            rewriter.create<xegpu::PrefetchNdOp>(
-                loc, tdesc, newOffsets, op.getL1HintAttr(), op.getL2HintAttr(),
-                op.getL3HintAttr());
-          }
-          rewriter.eraseOp(op);
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    for (auto [newOffsets, tdesc] :
+         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, newOffsets, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
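For context, the arithmetic that computeDistributedOffsets folds into IR is an element-wise add: each subgroup's relative offsets are added to the trailing dimensions of the workgroup-level offsets (wgOffsets may have higher rank than the subgroup offsets). Below is a minimal standalone sketch of that index math using plain integers in place of MLIR Value/OpFoldResult; the function and variable names are hypothetical and only illustrate the mapping, they are not part of the patch.

// Standalone sketch (hypothetical, not part of the patch): the same
// trailing-dimension offset math as computeDistributedOffsets, with int64_t
// standing in for MLIR Value/OpFoldResult.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<std::vector<int64_t>>
distributeOffsets(const std::vector<std::vector<int64_t>> &sgOffsetsList,
                  const std::vector<int64_t> &wgOffsets) {
  std::vector<std::vector<int64_t>> result;
  for (const auto &sgOffsets : sgOffsetsList) {
    assert(sgOffsets.size() <= wgOffsets.size());
    std::vector<int64_t> newOffsets;
    size_t rank = sgOffsets.size();
    for (size_t i = 0; i < rank; ++i) {
      // Align subgroup offsets with the trailing dims of the op's offsets.
      size_t idx = wgOffsets.size() - rank + i;
      newOffsets.push_back(wgOffsets[idx] + sgOffsets[i]);
    }
    result.push_back(std::move(newOffsets));
  }
  return result;
}

int main() {
  // Example: a 2x2 subgroup grid over a 64x64 workgroup tile (sgShape 32x32),
  // so the relative subgroup offsets are multiples of 32 in each dimension.
  std::vector<std::vector<int64_t>> sgOffsetsList = {
      {0, 0}, {0, 32}, {32, 0}, {32, 32}};
  std::vector<int64_t> wgOffsets = {128, 256}; // offsets carried by the wg op
  for (const auto &offs : distributeOffsets(sgOffsetsList, wgOffsets))
    std::cout << offs[0] << ", " << offs[1] << "\n"; // 128,256 ... 160,288
  return 0;
}

The in-tree version builds the same additions with index::AddOp via createOrFold, so constant offsets fold away, and the three WithOffset patterns then differ only in how they consume the resulting per-subgroup offsets (load, store, or prefetch).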