
Commit eeb07f7

Merge commit 'ea4bdaf9d662e36a52ea422a37daa4e2e1abad30'
2 parents: 586824a + ea4bdaf

10 files changed: +63, -32 lines

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 9 additions & 0 deletions
@@ -493,6 +493,15 @@ struct MemDescIndexOpConversion
     auto prevOffsets = smemObj.getOffsets();
     SmallVector<Value> offsetVals(prevOffsets.end() - dstTy.getRank(),
                                   prevOffsets.end());
+
+    // Apply padding based on the amount we move the base ptr
+    if (auto padEnc = dyn_cast<PaddedSharedEncodingAttr>(dstTy.getEncoding())) {
+      auto bitwidth = dstTy.getElementTypeBitWidth();
+      Value padOffset = emitPadding(loc, rewriter, padEnc, bitwidth, offset,
+                                    /*offsetInBytes=*/false);
+      offset = b.add(offset, padOffset);
+    }
+
     // Advance the pointer and keep the opOffsets as the new shape
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
                                  llvmElemTy, offsetVals);
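For intuition, the padding adjustment amounts to a per-interval shift-and-add: for an encoding like #ttg.padded_shared<[interval:+pad]>, every `interval` elements of linear offset contribute `pad` extra elements of padding, which is the ashr/shl pattern checked in the MLIR test below. A minimal Python sketch of that formula (assuming power-of-two intervals and pads; this is not the actual emitPadding implementation):

def pad_offset(offset: int, intervals_and_pads) -> int:
    # Sum the padding contributed by each (interval, pad) pair:
    # (offset >> log2(interval)) << log2(pad).
    total = 0
    for interval, pad in intervals_and_pads:
        total += (offset >> (interval.bit_length() - 1)) << (pad.bit_length() - 1)
    return total


# e.g. #ttg.padded_shared<[128:+4]>: advancing the base by 64*64 elements
# adds (4096 >> 7) << 2 = 128 padding elements to the pointer offset.
print(pad_offset(64 * 64, [(128, 4)]))  # -> 128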

python/test/unit/runtime/test_bindings.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def walk_fn(op):
     triton._C.libtriton.ir.load_dialects(context)
     backend.load_dialects(context)

-    ttir_module = src.make_ir(options, codegen_fns, module_map, context)
+    ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
     ttir_module.walk(walk_fn)

python/triton/_filecheck.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def run_parser(kernel_fn):
     options = stub_backend.parse_options(options)
     codegen_fns = stub_backend.get_codegen_implementation(options)
     module_map = stub_backend.get_module_map()
-    module = src.make_ir(options, codegen_fns, module_map, context)
+    module = src.make_ir(stub_target, options, codegen_fns, module_map, context)
     assert module.verify()
     return module


python/triton/compiler/compiler.py

Lines changed: 3 additions & 3 deletions
@@ -77,7 +77,7 @@ def hash(self):
         key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
         from .code_generator import ast_to_ttir
         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                            module_map=module_map)
@@ -116,7 +116,7 @@ def __init__(self, path, context, backend):
     def hash(self):
         return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
         self.module.context = context
         return self.module

@@ -303,7 +303,7 @@ def compile(src, target=None, options=None, _env_vars=None):
     codegen_fns = backend.get_codegen_implementation(options)
     module_map = backend.get_module_map()
     try:
-        module = src.make_ir(options, codegen_fns, module_map, context)
+        module = src.make_ir(target, options, codegen_fns, module_map, context)
     except Exception as e:
         filter_traceback(e)
         raise
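For downstream code that implements its own source objects, the updated contract simply takes the backend target as the first argument after self. A minimal sketch (only the signature comes from this diff; the class name and body are illustrative):

from triton.backends.compiler import GPUTarget


class MySource:
    # Hypothetical source wrapper matching the new make_ir signature.
    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
        # compile() now passes the GPU target in explicitly, so make_ir no
        # longer needs to query the active driver for the current target.
        raise NotImplementedError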

python/triton/experimental/gluon/_runtime.py

Lines changed: 1 addition & 3 deletions
@@ -1,5 +1,4 @@
 from __future__ import annotations
-import triton
 from triton.compiler.compiler import ASTSource
 from triton.backends.compiler import Language
 from triton.runtime.jit import JITFunction
@@ -16,15 +15,14 @@ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
         self.language = Language.GLUON
         self.ext = "ttgir"

-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target, options, codegen_fns, module_map, context):
         from triton.compiler.compiler import make_backend
         from triton.compiler.code_generator import ast_to_ttir

         builder = ir.builder(context)
         module = builder.create_module()

         # Assign module attributes eagerly, as they are needed to verify layouts
-        target = triton.runtime.driver.active.get_current_target()
         backend = make_backend(target)
         target = backend.get_target_name(options)

python/triton/runtime/_allocation.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
 from typing import Optional, Protocol
+from contextvars import ContextVar


 class Buffer(Protocol):
@@ -20,7 +21,7 @@ def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
                         "Use triton.set_allocator to specify an allocator.")


-_allocator: Allocator = NullAllocator()
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())


 def set_allocator(allocator: Allocator):
@@ -29,4 +30,4 @@ def set_allocator(allocator: Allocator):
     require additional global memory workspace.
     """
     global _allocator
-    _allocator = allocator
+    _allocator.set(allocator)
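A rough usage sketch of the ContextVar-based allocator: user code still calls triton.set_allocator, and launchers now read it with _allocator.get(), so the value is scoped to the current execution context (thread or asyncio task) rather than one process-wide global. The torch-backed allocator below is illustrative; any callable matching the Allocator protocol (size, alignment, stream) -> Buffer works:

import torch
import triton


def torch_allocator(size: int, alignment: int, stream):
    # Provide scratch memory for kernels that request a global workspace;
    # a torch tensor serves as a Buffer since it exposes a device pointer.
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(torch_allocator)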

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 14 additions & 18 deletions
@@ -412,34 +412,30 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n

 // -----

-// CHECK-LABEL: padded_shared_layout_subview
+// GFX950-LABEL: padded_shared_layout_subview
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}>
+#shared = #ttg.padded_shared<[128:+4] {order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func @padded_shared_layout_subview(%arg0: !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable>) {
     %c0_i32 = arith.constant 0 : i32
     %c1_i32 = arith.constant 1 : i32
-    // Skip two constants from the stride calculation
+    // Skip three constants from the stride calculation
+    // GFX950: llvm.mlir.constant
+    // GFX950: llvm.mlir.constant
+    // GFX950: llvm.mlir.constant

-    // CHECK-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
-    // CHECK-DAG: %[[CST3:.+]] = llvm.mlir.constant(3 : i32)
-    // CHECK-DAG: %[[CST4:.+]] = llvm.mlir.constant(4 : i32)
-    // CHECK-DAG: %[[CST8:.+]] = llvm.mlir.constant(8 : i32)
-    // CHECK-DAG: %[[CST9:.+]] = llvm.mlir.constant(9 : i32)
+    // GFX950-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
+    // GFX950-DAG: %[[CST7:.+]] = llvm.mlir.constant(7 : i32)
+    // GFX950-DAG: %[[CST2:.+]] = llvm.mlir.constant(2 : i32)

-    // CHECK: %[[SHR0:.+]] = llvm.ashr %[[ADD:.+]], %[[CST8]] : i32
-    // CHECK-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST3]] : i32
-    // CHECK-NEXT: %[[ADD0:.+]] = llvm.add %[[SHL0]], %[[CST0]] : i32
-    // CHECK-NEXT: %[[SHR1:.+]] = llvm.ashr %[[ADD]], %[[CST9]] : i32
-    // CHECK-NEXT: %[[SHL1:.+]] = llvm.shl %[[SHR1]], %[[CST4]] : i32
-    // CHECK-NEXT: %[[ADD1:.+]] = llvm.add %[[ADD0]], %[[SHL1]] : i32
-    // CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[ADD]], %[[ADD1]] : i32
-    // CHECK: llvm.getelementptr inbounds %{{.+}}[%[[ADD2]]]
+    // GFX950: %[[SHR0:.+]] = llvm.ashr %[[ADD:.+]], %[[CST7]] : i32
+    // GFX950-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST2]] : i32
+    // GFX950-NEXT: %[[ADD1:.+]] = llvm.add %[[CST0]], %[[SHL0]] : i32
+    // GFX950-NEXT: %[[ADD2:.+]] = llvm.add %[[ADD]], %[[ADD1]] : i32
+    // GFX950: llvm.getelementptr %{{.+}}[%[[ADD2]]]

     %1 = ttg.memdesc_index %arg0, %c1_i32 : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
-    %2 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
-    ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
     tt.return
   }
 }

test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir

Lines changed: 28 additions & 2 deletions
@@ -137,12 +137,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {

-// CHECK-LABEL: reject_chained_dots_empty_mem_cluster
+// CHECK-LABEL: reject_chained_dots_empty_mem_cluster_1

 // CHECK-NOT: setprio
 // CHECK-NOT: barrier

-tt.func @reject_chained_dots_empty_mem_cluster(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
+tt.func @reject_chained_dots_empty_mem_cluster_1(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
   %c1_i32 = arith.constant 1 : i32
   %c0_i32 = arith.constant 0 : i32
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
@@ -164,3 +164,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   tt.return %5#0 : tensor<128x16xf32, #mma>
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 8, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+
+// CHECK-LABEL: reject_chained_dots_empty_mem_cluster_2
+
+// CHECK-NOT: setprio
+// CHECK-NOT: barrier
+
+tt.func @reject_chained_dots_empty_mem_cluster_2(%memdesc1: !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, %memdesc2: !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, %alloc1: !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>, %alloc2: !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>, %arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
+  %5:8 = scf.for %arg14 = %arg3 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %memdesc1, %arg19 = %memdesc1, %arg20 = %memdesc2, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 {
+    %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
+    ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
+    %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %13 = tt.load %arg1 : tensor<64x16x!tt.ptr<f16>, #blocked>
+    %10 = tt.dot %arg10, %arg17, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
+    scf.yield %10, %6, %11, %arg19, %arg20, %arg20, %13, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>
+  }
+  tt.return %5#0 : tensor<128x16xf32, #mma>
+}
+}

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 1 addition & 1 deletion
@@ -678,7 +678,7 @@ LogicalResult Pingponger::transformChainedDotSchedule(OpBuilder &builder,

   // Memory clusters start with either ttg.async_wait or ttg.local_store
   auto findNextMemoryCluster = [](Operation *op) {
-    while (!llvm::isa_and_nonnull<ttg::AsyncWaitOp, ttg::LocalStoreOp>(op)) {
+    while (op && !llvm::isa<ttg::AsyncWaitOp, ttg::LocalStoreOp>(op)) {
      op = op->getNextNode();
    }
    return op;
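A sketch of the behavioral fix in Python, with a stand-in Op type (hypothetical; only the loop shape mirrors the C++): the old condition used isa_and_nonnull, which stays true when op is null and then dereferences it via getNextNode, while the new condition stops at the end of the block and returns null. The reject_chained_dots_empty_mem_cluster_2 test added above appears to cover this case.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Op:
    # Minimal stand-in for an MLIR Operation in a block's op list.
    name: str
    next: Optional["Op"] = None


def find_next_memory_cluster(op: Optional[Op]) -> Optional[Op]:
    # Walk forward until an op that starts a memory cluster
    # (ttg.async_wait or ttg.local_store), or until the end of the block.
    while op is not None and op.name not in ("ttg.async_wait", "ttg.local_store"):
        op = op.next
    return op


# End of block without a memory cluster: returns None instead of crashing.
assert find_next_memory_cluster(Op("tt.return")) is None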

third_party/nvidia/backend/driver.py

Lines changed: 2 additions & 1 deletion
@@ -727,7 +727,8 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
         if self.global_scratch_size > 0:
             grid_size = gridX * gridY * gridZ
             alloc_size = grid_size * self.num_ctas * self.global_scratch_size
-            global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
+            alloc_fn = _allocation._allocator.get()
+            global_scratch = alloc_fn(alloc_size, self.global_scratch_align, stream)
         else:
             global_scratch = None
         self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
