
Commit 2f40358

[TMA] Move ptxas bug workaround to TMAToLLVM and gate on ptx version (#5408)
There is a bug in the version of `ptxas` we have that treats `global_stride` values as if they were 16-byte strides rather than single-byte strides (presumably because the tensormap structure packs them in that format). This change centralizes the workaround at a single point in the code and gates it on the current PTX version.
1 parent 3563aec commit 2f40358
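For illustration, a minimal standalone sketch of the idea (the helper name is hypothetical, not the Triton code): the stride written into the descriptor stays in bytes on fixed toolchains, and is only shifted right by 4 (divided by 16) when the bundled `ptxas` still interprets it in 16-byte units.

```c++
#include <cstdint>

// Hypothetical sketch of the gating logic, not the actual lowering code.
// On affected toolchains (PTX version <= 8.5 in this change), the global
// stride stored in the tensormap must be pre-divided by 16, so the byte
// stride is arithmetically shifted right by 4.
int64_t encodeGlobalStride(int64_t strideInBytes, int ptxVersion) {
  bool needsStrideWorkaround = ptxVersion <= 85; // 85 == PTX ISA 8.5
  return needsStrideWorkaround ? (strideInBytes >> 4) : strideInBytes;
}
```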

File tree

7 files changed: +238 / -244 lines


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ using namespace mlir::triton;
 #define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
 #define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
 #define lshr(...) rewriter.create<LLVM::LShrOp>(loc, __VA_ARGS__)
+#define ashr(...) rewriter.create<LLVM::AShrOp>(loc, __VA_ARGS__)
 #define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
 #define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
 #define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
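The new `ashr` macro mirrors the neighboring helpers: it expands to an `LLVM::AShrOp` built through the local `rewriter` and `loc`. A rough usage sketch, assuming a 64-bit `stride` Value and the `i64_val` constant helper available in the lowering code:

```c++
// Sketch only: assumes `rewriter`, `loc`, and a 64-bit `stride` Value are in
// scope inside a conversion pattern, as they are where the macro is used.
Value packed = ashr(stride, i64_val(4));
// ...which is shorthand for:
Value packed2 = rewriter.create<LLVM::AShrOp>(loc, stride, i64_val(4));
```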

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 0 additions & 5 deletions
@@ -52,11 +52,6 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
       loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
   Value globalStride = builder.template create<arith::MulIOp>(
       loc, op.getStrides()[0], elemSizeVal);
-  // TODO: Workaround for ptxas bug, remove when we update ptxas
-  Value four = builder.template create<arith::ConstantOp>(
-      loc, builder.getI64Type(), builder.getI64IntegerAttr(4));
-  globalStride =
-      builder.template create<arith::ShRSIOp>(loc, globalStride, four);
 
   int elemTypeEnum;
   switch (elemSize) {

test/TritonGPU/samples/simulated-grouped-gemm.mlir

Lines changed: 227 additions & 234 deletions
Large diff not rendered.

test/TritonNvidiaGPU/tma_lowering.mlir

Lines changed: 1 addition & 2 deletions
@@ -38,8 +38,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK-LABEL: make_tensor_descriptor
 // CHECK: %0 = arith.extsi %arg2 : i32 to i64
 // CHECK: %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr<i8>
-// CHECK: %2 = arith.shrsi %0, %c4_i64 : i64
-// CHECK: tt.experimental_tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%2], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr<i8>, !tt.ptr<i8>, i32, i32, i32, i32, i64, i32, i32) -> ()
+// CHECK: tt.experimental_tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%0], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr<i8>, !tt.ptr<i8>, i32, i32, i32, i32, i64, i32, i32) -> ()
 // CHECK: tt.experimental_tensormap_fenceproxy_acquire %1 : !tt.ptr<i8>
 // CHECK: tt.reinterpret_tensor_descriptor %1 : !tt.ptr<i8> to !tt.tensordesc<tensor<8x32xi8>>
 tt.func public @make_tensor_descriptor(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.tensordesc<tensor<8x32xi8>> {

third_party/nvidia/language/cuda/_experimental_tma.py

Lines changed: 0 additions & 2 deletions
@@ -68,8 +68,6 @@ def experimental_device_tensormap_create2d(
     element_size = element_ty.primitive_bitwidth // 8
     element_size_t = core.full([], element_size, core.int64, _builder=_builder)
     global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder)
-    # Undocumented, but global_stride seems to be divided by 16
-    global_stride = semantic.ashr(global_stride, semantic.to_tensor(4, _builder), _builder)
 
     contig_dim_size_in_bytes = element_size * load_size[-1]
     if contig_dim_size_in_bytes > 128:

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp

Lines changed: 7 additions & 1 deletion
@@ -249,6 +249,7 @@ struct ExperimentalTensormapCreateOpConversion
     Location loc = op->getLoc();
     auto ctx = getContext();
 
+    bool needsStrideWorkaround = targetInfo.getPtxVersion() <= 85;
     auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
 
     zero_fill_tma(loc, ctx, rewriter, targetInfo, smemBase);
@@ -264,8 +265,13 @@ struct ExperimentalTensormapCreateOpConversion
                                    op.getGlobalDim()[i]);
     }
     for (int i = 0; i + 1 < op.getRank(); ++i) {
+      auto strideVal = op.getGlobalStride()[i];
+      if (needsStrideWorkaround) {
+        // Workaround for a ptxas bug
+        strideVal = ashr(strideVal, i64_val(4));
+      }
       tensormap_replace_global_stride(loc, ctx, rewriter, smemBase, i,
-                                      op.getGlobalStride()[i]);
+                                      strideVal);
     }
     for (int i = 0; i < op.getRank(); ++i) {
       tensormap_replace_element_stride(loc, ctx, rewriter, smemBase, i,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
 
   bool supportVectorizedAtomics() const override;
 
+  int getPtxVersion() const { return ptxVersion; }
+
 private:
   int computeCapability;
   int ptxVersion;
