Commit 4593bcd
Support IR (ttir/ttgir/llir/ptx) overriding at scale by Triton autotune configs (#6802)
This is an alternative to triton-lang/triton#6790, where we tried to allow users to override kernels with customized PTX source files. Unlike triton-lang/triton#6790, which filled in the missing metadata in a brittle way, this PR simplifies the logic by adding `ir_override` to the Triton autotune config. We keep the legitimate metadata produced by `make_ttgir` / `make_llir`, while users can still load their customized PTX files at scale, as below:

```
bwd_configs_ws = [
    triton_config(
        {
            "BLOCK_M1": BM1,
            "BLOCK_N1": BN1,
            "BLOCK_M2": BN1,
            "BLOCK_N2": BM1,
            "NUM_CONSUMER_GROUPS": 2,
        },
        num_stages=s,
        num_warps=w,
        num_buffers_warp_spec=buf,
        num_consumer_groups=2,
        ir_override=f"/data/users/daohang/test_bwd/customized_block_n1_{BN1}.ptx",
    )
    for buf in [2]
    for BM1 in [64]
    for BN1 in [64, 128]
    for s in ([1, 2] if is_hip() else [0])
    for w in [4]
]
```

Besides PTX, ttir/ttgir/llir are also supported. See `pytest python/test/unit/runtime/test_autotuner.py -k test_override` for more details.
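For reference, a minimal sketch of the merged interface: `ir_override` is an ordinary `triton.Config` argument, so one tuning point can load a hand-edited IR file while the others compile from the Python source. The copy kernel and the `.ttgir` path below are placeholders for illustration, not part of this commit:

```
import triton
import triton.language as tl

configs = [
    # Placeholder path: must be an existing file whose extension
    # (.ttir/.ttgir/.llir/.ptx) matches the compilation stage to replace.
    triton.Config({"BLOCK_SIZE": 32}, ir_override="/tmp/copy_kernel_block32.ttgir"),
    triton.Config({"BLOCK_SIZE": 64}),  # ordinary config, compiled from the Python source
]


@triton.autotune(configs=configs, key=["N"])
@triton.jit
def _copy(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    tl.store(dst + offsets, tl.load(src + offsets, mask=mask), mask=mask)
```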
1 parent 7dc5492 commit 4593bcd

File tree

4 files changed: +213 −2 lines changed

python/test/unit/runtime/test_autotuner.py

Lines changed: 201 additions & 0 deletions
```
@@ -4,6 +4,10 @@
 import triton.language as tl
 import pytest
 
+import pathlib
+import uuid
+from triton._internal_testing import is_cuda
+
 
 def do_bench(kernel_call, quantiles, use_cuda_graph=False):
     if use_cuda_graph:
@@ -169,6 +173,203 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
     assert records['capture_named_args']
 
 
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9,
+                    reason="Requires compute capability >= 9 for NV")
+def test_override_ttir(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
+    ir_src = r"""
+module {
+  tt.func public @_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<1.000000e+01> : tensor<32xf32>
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c32_i32 : i32
+    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
+    %3 = tt.splat %1 : i32 -> tensor<32xi32>
+    %4 = arith.addi %3, %2 : tensor<32xi32>
+    %5 = tt.splat %arg2 : i32 -> tensor<32xi32>
+    %6 = arith.cmpi slt, %4, %5 : tensor<32xi32>
+    %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
+    %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
+    %9 = tt.load %8, %6 : tensor<32x!tt.ptr<f32>>
+    %10 = arith.mulf %9, %cst : tensor<32xf32>
+    %11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
+    %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
+    tt.store %12, %10, %6 : tensor<32x!tt.ptr<f32>>
+    tt.return
+  }
+}
+"""
+    temp_file = pathlib.Path(f"/tmp/test_override_{str(uuid.uuid4())}.ttir")
+    temp_file.write_text(ir_src)
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    _kernel[grid](dst, src, N=N)
+
+    # Change the behavior of kernel by overriding PTX
+    torch.testing.assert_close(src * 10, dst)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9,
+                    reason="Requires compute capability >= 9 for NV")
+def test_override_ttgir(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
+    ir_src = r"""
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<1.000000e+01> : tensor<32xf32, #blocked>
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c32_i32 : i32
+    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<32xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<32xi32, #blocked>
+    %5 = tt.splat %arg2 : i32 -> tensor<32xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<32xi32, #blocked>
+    %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr<f32>, #blocked>, tensor<32xi32, #blocked>
+    %9 = tt.load %8, %6 : tensor<32x!tt.ptr<f32>, #blocked>
+    %10 = arith.mulf %9, %cst : tensor<32xf32, #blocked>
+    %11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>, #blocked>
+    %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr<f32>, #blocked>, tensor<32xi32, #blocked>
+    tt.store %12, %10, %6 : tensor<32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+"""
+    temp_file = pathlib.Path(f"/tmp/test_override_{str(uuid.uuid4())}.ttgir")
+    temp_file.write_text(ir_src)
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    _kernel[grid](dst, src, N=N)
+
+    # Change the behavior of kernel by overriding PTX
+    torch.testing.assert_close(src * 10, dst)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9,
+                    reason="PTX file in this unit test is only for SM90")
+def test_override_ptx(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
+    ir_src = r"""
+//
+// Generated by LLVM NVPTX Back-End
+//
+
+.version 8.7
+.target sm_90a
+.address_size 64
+
+    // .globl _kernel    // -- Begin function _kernel
+    // @_kernel
+.visible .entry _kernel(
+    .param .u64 .ptr .global .align 1 _kernel_param_0,
+    .param .u64 .ptr .global .align 1 _kernel_param_1,
+    .param .u32 _kernel_param_2,
+    .param .u64 .ptr .global .align 1 _kernel_param_3
+)
+.reqntid 128
+{
+    .reg .pred %p<4>;
+    .reg .b32 %r<10>;
+    .reg .b32 %f<3>;
+    .reg .b64 %rd<6>;
+    .loc 1 180 0
+$L__func_begin0:
+    .loc 1 180 0
+
+// %bb.0:
+    ld.param.u64 %rd3, [_kernel_param_0];
+    ld.param.u64 %rd4, [_kernel_param_1];
+$L__tmp0:
+    .loc 1 181 28
+    mov.u32 %r3, %ctaid.x;
+    .loc 1 181 33
+    shl.b32 %r4, %r3, 5;
+    ld.param.u32 %r5, [_kernel_param_2];
+    .loc 1 181 59
+    mov.u32 %r6, %tid.x;
+    and.b32 %r7, %r6, 31;
+    .loc 1 181 46
+    or.b32 %r8, %r4, %r7;
+    .loc 1 182 46
+    setp.lt.s32 %p1, %r8, %r5;
+    .loc 1 182 22
+    mul.wide.s32 %rd5, %r8, 4;
+    add.s64 %rd1, %rd4, %rd5;
+    .loc 1 182 16
+    // begin inline asm
+    mov.u32 %r1, 0x0;
+    @%p1 ld.global.b32 { %r1 }, [ %rd1 + 0 ];
+    // end inline asm
+    mov.b32 %f1, %r1;
+    .loc 1 183 12
+    mul.f32 %f2, %f1, 0f41200000;
+    .loc 1 184 19
+    add.s64 %rd2, %rd3, %rd5;
+    .loc 1 184 28
+    and.b32 %r9, %r6, 96;
+    setp.eq.s32 %p3, %r9, 0;
+    mov.b32 %r2, %f2;
+    and.pred %p2, %p3, %p1;
+    // begin inline asm
+    @%p2 st.global.b32 [ %rd2 + 0 ], { %r2 };
+    // end inline asm
+    .loc 1 184 4
+    ret;
+$L__tmp1:
+$L__func_end0:
+    // -- End function
+}
+"""
+    temp_file = pathlib.Path(f"/tmp/test_override_{str(uuid.uuid4())}.ptx")
+    temp_file.write_text(ir_src)
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        x = x * 10
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    _kernel[grid](dst, src, N=N)
+
+    # Change the behavior of kernel by overriding PTX
+    torch.testing.assert_close(src * 10, dst)
+
+
 def test_exceed_tmem(device):
     if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10:
         pytest.skip("Test requires tensor memory.")
```

python/triton/compiler/compiler.py

Lines changed: 6 additions & 1 deletion
```
@@ -348,7 +348,12 @@ def compile(src, target=None, options=None):
     for ext, compile_ir in list(stages.items())[first_stage:]:
         next_module = compile_ir(module, metadata)
         ir_filename = f"{file_name}.{ext}"
-        if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
+        if fn_override_manager is None:
+            # Users can override kernels at scale by setting `ir_override` in autotune config
+            # without TRITON_KERNEL_OVERRIDE
+            if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
+                next_module = parse(ir_override, ext, context)
+        elif full_name := fn_override_manager.get_file(ir_filename):
             print(f"\nOverriding kernel with file {full_name}")
             next_module = parse(full_name, ext, context)
         # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
```
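To make the control flow of this hunk easier to read in isolation, here is a standalone sketch (not the Triton source; `resolve_override` is a hypothetical helper name) of the override resolution it implements:

```
def resolve_override(ext, metadata, fn_override_manager, ir_filename):
    """Return the path of an IR file to parse instead of the freshly compiled stage, or None."""
    if fn_override_manager is None:
        # Per-config override from the autotuner: only fires for the stage whose
        # extension matches the override file, e.g. ext == "ptx" for "kernel.ptx".
        ir_override = metadata.get("ir_override")
        if ir_override and ir_override.endswith(f".{ext}"):
            return ir_override
        return None
    # When TRITON_KERNEL_OVERRIDE activates the file manager, that path wins.
    return fn_override_manager.get_file(ir_filename)
```

In other words, `ir_override` is consulted only when the `TRITON_KERNEL_OVERRIDE` file manager is not in use, and it replaces exactly one stage: the one whose extension matches the override file.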

python/triton/runtime/autotuner.py

Lines changed: 5 additions & 1 deletion
```
@@ -313,15 +313,17 @@ class Config:
                    to ptx .maxnreg directive. Not supported on all platforms.
     :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
                     function are args.
+    :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
     """
 
-    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None):
+    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_ctas = num_ctas
         self.num_stages = num_stages
         self.maxnreg = maxnreg
         self.pre_hook = pre_hook
+        self.ir_override = ir_override
 
     def __setstate__(self, state):
         self.kwargs = state.get("kwargs", {})
@@ -330,6 +332,7 @@ def __setstate__(self, state):
         self.num_ctas = state.get("num_ctas", 1)
         self.maxnreg = state.get("maxnreg", None)
         self.pre_hook = state.get("pre_hook", None)
+        self.ir_override = state.get("ir_override", None)
 
     def all_kwargs(self):
         return {
@@ -340,6 +343,7 @@ def all_kwargs(self):
                     ("num_ctas", self.num_ctas),
                     ("num_stages", self.num_stages),
                     ("maxnreg", self.maxnreg),
+                    ("ir_override", self.ir_override),
                 ) if v is not None
             }
         }
```
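A small usage sketch of the `Config` change (assuming a Triton build that includes this patch; the path is a placeholder): the new field rides along with the other meta-parameters through `all_kwargs()`, which is how it eventually reaches the backend options below.

```
import triton

# Placeholder path; any .ttir/.ttgir/.llir/.ptx file would do.
cfg = triton.Config({"BLOCK_SIZE": 32}, num_warps=4, ir_override="/tmp/example.ptx")

kw = cfg.all_kwargs()
assert kw["BLOCK_SIZE"] == 32
assert kw["ir_override"] == "/tmp/example.ptx"

# A config without ir_override is unchanged: the key is simply absent.
assert "ir_override" not in triton.Config({"BLOCK_SIZE": 32}).all_kwargs()
```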

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
```
@@ -106,6 +106,7 @@ class CUDAOptions:
     maxnreg: Optional[int] = None
     cluster_dims: tuple = (1, 1, 1)
     ptx_version: int = None
+    ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
     launch_pdl: bool = False
```
