@@ -635,6 +635,50 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
635
635
}
636
636
};
637
637
638
+ // / Slice only offset and keep base - i.e.,
639
+ // / slice(fatPtrBase, fatPtrOffset) -> (fatPtrBase, slice(fatPtrOffset))
640
+ class ConvertExtractSliceOp
641
+ : public PointerCanonicalizationPattern<tt::amdgpu::ExtractSliceOp> {
642
+ public:
643
+ using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
644
+
645
+ LogicalResult
646
+ matchAndRewrite_ (tt::amdgpu::ExtractSliceOp extractSliceOp,
647
+ OneToNOpAdaptor adaptor,
648
+ ConversionPatternRewriter &rewriter) const override {
649
+ ValueRange remappedOperands = adaptor.getSource ();
650
+ if (remappedOperands.size () != 2 )
651
+ return success ();
652
+
653
+ Value fatPtrBase = remappedOperands[0 ];
654
+ Value fatPtrOffset = remappedOperands[1 ];
655
+ if (!llvm::isa<tt::PointerType>(fatPtrBase.getType ()))
656
+ return rewriter.notifyMatchFailure (extractSliceOp,
657
+ " non tt.ptr base unimplemented" );
658
+
659
+ auto fatPtrOffsetTy = dyn_cast<RankedTensorType>(fatPtrOffset.getType ());
660
+ if (!fatPtrOffsetTy)
661
+ return rewriter.notifyMatchFailure (
662
+ extractSliceOp, " non RankedTensorType offset unimplemented" );
663
+
664
+ Location loc = extractSliceOp->getLoc ();
665
+ RankedTensorType resultType = extractSliceOp.getResult ().getType ();
666
+ auto slicedOffsetsTy = RankedTensorType::get (
667
+ resultType.getShape (), fatPtrOffsetTy.getElementType (),
668
+ resultType.getEncoding ());
669
+ Value slicedOffsets = rewriter.create <tt::amdgpu::ExtractSliceOp>(
670
+ loc, Type{slicedOffsetsTy}, Value{fatPtrOffset},
671
+ extractSliceOp.getStaticOffsetsAttr ());
672
+
673
+ rewriter.replaceOpWithMultiple (extractSliceOp,
674
+ {{fatPtrBase, slicedOffsets}});
675
+ fatPtrs[{fatPtrBase, slicedOffsets}] =
676
+ fatPtrs.at ({fatPtrBase, fatPtrOffset});
677
+
678
+ return success ();
679
+ }
680
+ };
681
+
638
682
 /// Rewrite init args, result type, and bb args.
639
683
class ConvertSCFForOp : public PointerCanonicalizationPattern <scf::ForOp> {
640
684
using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
@@ -1510,6 +1554,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
1510
1554
target.addDynamicallyLegalDialect <scf::SCFDialect>(isLegal);
1511
1555
target.addDynamicallyLegalDialect <cf::ControlFlowDialect>(isLegal);
1512
1556
target.addDynamicallyLegalDialect <arith::ArithDialect>(isLegal);
1557
+ target.addDynamicallyLegalDialect <triton::amdgpu::TritonAMDGPUDialect>(
1558
+ isLegal);
1513
1559
1514
1560
// Rewrite the rest of the ops.
1515
1561
// Note we *do not* declare unrealized_cast an illegal op here in order that
@@ -1521,7 +1567,7 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
1521
1567
RewritePatternSet patterns (&getContext ());
1522
1568
patterns.add <
1523
1569
ConvertFuncOpArgsUnrealizedCasts, ConvertBroadcastOp, ConvertSplatOp,
1524
- ConvertConvertLayoutOp, ConvertAddPtrOp,
1570
+ ConvertConvertLayoutOp, ConvertAddPtrOp, ConvertExtractSliceOp,
1525
1571
MaterializeFatPointer<tt::AtomicCASOp>,
1526
1572
MaterializeFatPointer<tt::AtomicRMWOp>,
1527
1573
MaterializeFatPointer<tt::BitcastOp>, MaterializeFatPointer<tt::LoadOp>,
0 commit comments