@@ -1654,8 +1654,10 @@ static const std::string kInitFuncArgsRewritten =
16541654// / (ConvertUnimplementedOpUnrealizedCasts) if it wasn't DCEd (via a user
16551655// / extracting the tt.ptr and c0 operands).
16561656struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
1657- InitFuncPtrArgs (MLIRContext *context, FatPointers &fatPtrs)
1658- : OpRewritePattern(context, 0 ), fatPtrs(fatPtrs) {}
1657+ InitFuncPtrArgs (MLIRContext *context, FatPointers &fatPtrs,
1658+ bool enableLargeTensorPtrCanon_)
1659+ : OpRewritePattern(context, 0 ), fatPtrs(fatPtrs),
1660+ enableLargeTensorPtrCanon (enableLargeTensorPtrCanon_) {}
16591661
16601662 LogicalResult matchAndRewrite (tt::FuncOp newOp,
16611663 PatternRewriter &rewriter) const override {
@@ -1673,7 +1675,11 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
16731675 newOp.getArgAttrOfType <IntegerAttr>(idx, " tt.pointer_range" ))
16741676 bitness = pointerRangeAttr.getInt ();
16751677
1676- LDBG (idx << " -th argument: " << arg << " , bitness: " << bitness << " \n " );
1678+ LDBG (idx << " -th argument: " << arg << " , bitness: " << bitness);
1679+ if (!enableLargeTensorPtrCanon && (bitness == 64 )) {
1680+ LDBG (" Do not init argument of large-tensor pointer: " << arg);
1681+ continue ;
1682+ }
16771683
16781684 Value zeroOffset =
16791685 rewriter.create <arith::ConstantIntOp>(newOp.getLoc (), 0 , bitness);
@@ -1690,6 +1696,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
16901696 }
16911697
16921698 FatPointers &fatPtrs;
1699+ bool enableLargeTensorPtrCanon;
16931700};
16941701
16951702// / No-op to make conversion framework happy.
@@ -1816,6 +1823,8 @@ class ConvertUnimplementedOpUnrealizedCasts
18161823class TritonAMDGPUCanonicalizePointersPass
18171824 : public impl::TritonAMDGPUCanonicalizePointersBase<
18181825 TritonAMDGPUCanonicalizePointersPass> {
1826+ using Base::Base;
1827+
18191828public:
18201829 void runOnOperation () override ;
18211830};
@@ -1905,18 +1914,29 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
19051914 FatPointers fatPrs;
19061915 PatternRewriter rewriter (&getContext ());
19071916 // Convert tt.func; %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr
1908- InitFuncPtrArgs pat (&getContext (), fatPrs);
1917+ InitFuncPtrArgs pat (&getContext (), fatPrs, enableLargeTensorPtrCanon );
19091918 if (failed (pat.matchAndRewrite (func, rewriter)))
19101919 return signalPassFailure ();
19111920
19121921 llvm::SetVector<Operation *> opsToRewrite;
1913- for (auto arg : func.getArguments ()) {
1914- if (llvm::isa<tt::PointerType>(arg.getType ())) {
1915- // NB: reusing the same SetVector invalidates the topo order implied by
1916- // getForwardSlice
1917- for (auto &use : arg.getUses ())
1918- getForwardSliceImpl (&use, use.getOwner (), &opsToRewrite);
1922+ for (auto [idx, arg] : llvm::enumerate (func.getArguments ())) {
1923+ if (!llvm::isa<tt::PointerType>(arg.getType ()))
1924+ continue ;
1925+
1926+ int64_t bitness = 64 ;
1927+ if (auto pointerRangeAttr =
1928+ func.getArgAttrOfType <IntegerAttr>(idx, " tt.pointer_range" ))
1929+ bitness = pointerRangeAttr.getInt ();
1930+
1931+ if (!enableLargeTensorPtrCanon && (bitness == 64 )) {
1932+ LDBG (" ignore " << idx << " -th argument of large-tensor ptr: " << arg);
1933+ continue ;
19191934 }
1935+
1936+ // NB: reusing the same SetVector invalidates the topo order implied by
1937+ // getForwardSlice
1938+ for (auto &use : arg.getUses ())
1939+ getForwardSliceImpl (&use, use.getOwner (), &opsToRewrite);
19201940 }
19211941
19221942 ConversionConfig config;
0 commit comments