
Commit 88b8a5c

yangshuxin and Shuxin Yang authored
[AMD] disable pointer-canonicalization for large-tensor (#8359)
This commit disables pointer canonicalization for pointers that point to large tensors. Here, "large tensors" refers to the JIT specialization applied to tensor arguments larger than 2GB. The rewrite is disabled for these pointers because it has some tricky bugs; we are working on a better approach that balances several conflicting performance aspects.

---------

Co-authored-by: Shuxin Yang <[email protected]>
1 parent d5156d7 commit 88b8a5c
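
Note on mechanics (not stated in the commit message itself): the pass decides which pointer arguments fall under this rule by inspecting the tt.pointer_range argument attribute, as the CanonicalizePointers.cpp changes below show. A minimal sketch of that check, with the helper name isLargeTensorPtrArg chosen here purely for illustration:

// Sketch of the classification this commit adds (context and includes as in
// CanonicalizePointers.cpp; the helper name is hypothetical).
static bool isLargeTensorPtrArg(tt::FuncOp func, unsigned argIdx) {
  // Offsets are assumed to need 64 bits unless the frontend attached a
  // narrower tt.pointer_range hint to this argument.
  int64_t bitness = 64;
  if (auto rangeAttr =
          func.getArgAttrOfType<IntegerAttr>(argIdx, "tt.pointer_range"))
    bitness = rangeAttr.getInt();
  // bitness == 64 corresponds to the over-2GB specialization; with the new
  // default enable-large-tensor-ptr-canon=false such arguments are skipped.
  return bitness == 64;
}

With the option left at its default (false), both the argument-initialization pattern and the forward-slice collection skip any argument for which this predicate holds; passing enable-large-tensor-ptr-canon=true restores the previous behavior, as the updated test RUN lines do.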

5 files changed, +57 -12 lines

test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 // NOTE: Assertions have been autogenerated by mlir/utils/generate-test-checks.py

-// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers -verify-diagnostics | FileCheck %s
+// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers="enable-large-tensor-ptr-canon=true" -verify-diagnostics | FileCheck %s

 module attributes {"ttg.num-warps" = 4 : i32} {
   tt.func @ifOpTwoYields(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>, %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>) {
Lines changed: 20 additions & 0 deletions (new test file)
@@ -0,0 +1,20 @@
+// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers="enable-large-tensor-ptr-canon=false" -canonicalize -verify-diagnostics | FileCheck %s
+
+// this case is copied from amd-canonicalize-pointers-no-large-tensor.mlir. With
+// enable-large-tensor-ptr-canon=false, the input is not changed at all.
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func @conversion1(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.splat %1 : i32 -> tensor<1024xi32>
+    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+    %4 = tt.addptr %3, %2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+    %5 = tt.load %4 : tensor<1024x!tt.ptr<f32>>
+    tt.return %5 : tensor<1024xf32>
+  }
+}
+
+// CHECK-LABEL: tt.func @conversion1
+// CHECK: %[[ADDPTR:.*]] = tt.addptr
+// CHECK: = tt.load %[[ADDPTR]]

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 // NOTE: Assertions have been autogenerated by mlir/utils/generate-test-checks.py

-// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize -verify-diagnostics | FileCheck %s
+// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers="enable-large-tensor-ptr-canon=true" -canonicalize -verify-diagnostics | FileCheck %s

 module attributes {"ttg.num-warps" = 4 : i32} {
   tt.func @conversion1(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 5 additions & 0 deletions
@@ -123,6 +123,11 @@ def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers"

   let dependentDialects = [];

+  let options = [
+    Option<"enableLargeTensorPtrCanon", "enable-large-tensor-ptr-canon",
+           "bool", /*default=*/"false",
+           "Whether to enable canonicalization for pointers pointing to large-tensors (a specialization for tensors over 2GB)">
+  ];
 }

 def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> {

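The new TableGen Option is what wires up both the enableLargeTensorPtrCanon member used by the C++ changes below and the enable-large-tensor-ptr-canon=... syntax in the test RUN lines above. For reference, a rough sketch of what mlir-tblgen typically generates for such an option (illustrative, not the literal generated code):

// Approximate shape of the generated pass option member (sketch only):
::mlir::Pass::Option<bool> enableLargeTensorPtrCanon{
    *this, "enable-large-tensor-ptr-canon",
    ::llvm::cl::desc("Whether to enable canonicalization for pointers "
                     "pointing to large-tensors (a specialization for "
                     "tensors over 2GB)"),
    ::llvm::cl::init(false)};

The default of false is what flips the pass to skipping large-tensor pointers unless a user opts back in.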
third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 30 additions & 10 deletions
@@ -1654,8 +1654,10 @@ static const std::string kInitFuncArgsRewritten =
 /// (ConvertUnimplementedOpUnrealizedCasts) if it wasn't DCEd (via a user
 /// extracting the tt.ptr and c0 operands).
 struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
-  InitFuncPtrArgs(MLIRContext *context, FatPointers &fatPtrs)
-      : OpRewritePattern(context, 0), fatPtrs(fatPtrs) {}
+  InitFuncPtrArgs(MLIRContext *context, FatPointers &fatPtrs,
+                  bool enableLargeTensorPtrCanon_)
+      : OpRewritePattern(context, 0), fatPtrs(fatPtrs),
+        enableLargeTensorPtrCanon(enableLargeTensorPtrCanon_) {}

   LogicalResult matchAndRewrite(tt::FuncOp newOp,
                                 PatternRewriter &rewriter) const override {
@@ -1673,7 +1675,11 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
             newOp.getArgAttrOfType<IntegerAttr>(idx, "tt.pointer_range"))
       bitness = pointerRangeAttr.getInt();

-    LDBG(idx << "-th argument: " << arg << ", bitness: " << bitness << "\n");
+    LDBG(idx << "-th argument: " << arg << ", bitness: " << bitness);
+    if (!enableLargeTensorPtrCanon && (bitness == 64)) {
+      LDBG("Do not init argument of large-tensor pointer: " << arg);
+      continue;
+    }

     Value zeroOffset =
         rewriter.create<arith::ConstantIntOp>(newOp.getLoc(), 0, bitness);
@@ -1690,6 +1696,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
   }

   FatPointers &fatPtrs;
+  bool enableLargeTensorPtrCanon;
 };

 /// No-op to make conversion framework happy.
@@ -1816,6 +1823,8 @@ class ConvertUnimplementedOpUnrealizedCasts
 class TritonAMDGPUCanonicalizePointersPass
     : public impl::TritonAMDGPUCanonicalizePointersBase<
           TritonAMDGPUCanonicalizePointersPass> {
+  using Base::Base;
+
 public:
   void runOnOperation() override;
 };
@@ -1905,18 +1914,29 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
   FatPointers fatPrs;
   PatternRewriter rewriter(&getContext());
   // Convert tt.func; %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr
-  InitFuncPtrArgs pat(&getContext(), fatPrs);
+  InitFuncPtrArgs pat(&getContext(), fatPrs, enableLargeTensorPtrCanon);
   if (failed(pat.matchAndRewrite(func, rewriter)))
     return signalPassFailure();

   llvm::SetVector<Operation *> opsToRewrite;
-  for (auto arg : func.getArguments()) {
-    if (llvm::isa<tt::PointerType>(arg.getType())) {
-      // NB: reusing the same SetVector invalidates the topo order implied by
-      // getForwardSlice
-      for (auto &use : arg.getUses())
-        getForwardSliceImpl(&use, use.getOwner(), &opsToRewrite);
+  for (auto [idx, arg] : llvm::enumerate(func.getArguments())) {
+    if (!llvm::isa<tt::PointerType>(arg.getType()))
+      continue;
+
+    int64_t bitness = 64;
+    if (auto pointerRangeAttr =
+            func.getArgAttrOfType<IntegerAttr>(idx, "tt.pointer_range"))
+      bitness = pointerRangeAttr.getInt();
+
+    if (!enableLargeTensorPtrCanon && (bitness == 64)) {
+      LDBG("ignore " << idx << "-th argument of large-tensor ptr: " << arg);
+      continue;
     }
+
+    // NB: reusing the same SetVector invalidates the topo order implied by
+    // getForwardSlice
+    for (auto &use : arg.getUses())
+      getForwardSliceImpl(&use, use.getOwner(), &opsToRewrite);
   }

   ConversionConfig config;
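
The "using Base::Base;" line added above pulls the tablegen-generated constructors into the pass class, which is how the option value reaches runOnOperation(). As a sketch, assuming the conventional generated Options struct name (it is not spelled out in this diff), the pass could now be constructed inside CanonicalizePointers.cpp with the flag set programmatically:

// Sketch only: the Options struct name follows the usual mlir-tblgen
// pattern and is an assumption, not something shown in this patch.
TritonAMDGPUCanonicalizePointersOptions options;
options.enableLargeTensorPtrCanon = true; // mirrors the test RUN lines
auto pass =
    std::make_unique<TritonAMDGPUCanonicalizePointersPass>(options);

From the command line, the equivalent is the enable-large-tensor-ptr-canon=true string the updated tests pass to -tritonamdgpu-canonicalize-pointers.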
