import triton.language as tl
import pytest

+import pathlib
+import uuid
+from triton._internal_testing import is_cuda
+

def do_bench(kernel_call, quantiles, use_cuda_graph=False):
    if use_cuda_graph:
@@ -169,6 +173,203 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
    assert records['capture_named_args']


+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9,
+                    reason="Requires compute capability >= 9 on NVIDIA")
+def test_override_ttir(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
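+    # Hand-written TTIR for the kernel below. Note the dense<1.000000e+01>
+    # constant and the arith.mulf: this override multiplies the input by 10,
+    # while the original Python kernel only copies src to dst, so the assert
+    # at the end of the test can only pass if the override was actually used.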
+    ir_src = r"""
+module {
+  tt.func public @_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<1.000000e+01> : tensor<32xf32>
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c32_i32 : i32
+    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
+    %3 = tt.splat %1 : i32 -> tensor<32xi32>
+    %4 = arith.addi %3, %2 : tensor<32xi32>
+    %5 = tt.splat %arg2 : i32 -> tensor<32xi32>
+    %6 = arith.cmpi slt, %4, %5 : tensor<32xi32>
+    %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
+    %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
+    %9 = tt.load %8, %6 : tensor<32x!tt.ptr<f32>>
+    %10 = arith.mulf %9, %cst : tensor<32xf32>
+    %11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
+    %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
+    tt.store %12, %10, %6 : tensor<32x!tt.ptr<f32>>
+    tt.return
+  }
+}
+    """
+    temp_file = pathlib.Path(f"/tmp/test_override_{uuid.uuid4()}.ttir")
+    temp_file.write_text(ir_src)
+
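+    # 'ir_override' points compilation at the dumped file; the file
+    # extension (.ttir here) selects which stage of the pipeline is replaced.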
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    _kernel[grid](dst, src, N=N)
+
+    # Change the behavior of the kernel by overriding its TTIR
+    torch.testing.assert_close(src * 10, dst)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9,
+                    reason="Requires compute capability >= 9 on NVIDIA")
+def test_override_ttgir(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
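+    # Same computation as the TTIR test, but written at the TTGIR level:
+    # every tensor now carries an explicit #blocked layout, and the module
+    # attributes (num-warps, target, threads-per-warp) must be consistent
+    # with how the kernel is compiled.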
+    ir_src = r"""
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<1.000000e+01> : tensor<32xf32, #blocked>
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c32_i32 : i32
+    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<32xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<32xi32, #blocked>
+    %5 = tt.splat %arg2 : i32 -> tensor<32xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<32xi32, #blocked>
+    %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr<f32>, #blocked>, tensor<32xi32, #blocked>
+    %9 = tt.load %8, %6 : tensor<32x!tt.ptr<f32>, #blocked>
+    %10 = arith.mulf %9, %cst : tensor<32xf32, #blocked>
+    %11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>, #blocked>
+    %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr<f32>, #blocked>, tensor<32xi32, #blocked>
+    tt.store %12, %10, %6 : tensor<32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+    """
+    temp_file = pathlib.Path(f"/tmp/test_override_{uuid.uuid4()}.ttgir")
+    temp_file.write_text(ir_src)
+
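+    # A .ttgir suffix substitutes the IR after the TTIR->TTGIR lowering,
+    # so the layouts and module attributes above are taken as-is.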
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    _kernel[grid](dst, src, N=N)
+
+    # Change the behavior of the kernel by overriding its TTGIR
+    torch.testing.assert_close(src * 10, dst)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9,
+                    reason="PTX file in this unit test is only for SM90")
+def test_override_ptx(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
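+    # Hand-written PTX for the compiled kernel, targeting sm_90a (hence the
+    # exact capability check above). The entry signature must match what
+    # Triton's launcher passes, including the extra trailing pointer
+    # parameter appended after the user-visible arguments.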
+    ir_src = r"""
+//
+// Generated by LLVM NVPTX Back-End
+//
+
+.version 8.7
+.target sm_90a
+.address_size 64
+
+    // .globl _kernel    // -- Begin function _kernel
+    // @_kernel
+.visible .entry _kernel(
+    .param .u64 .ptr .global .align 1 _kernel_param_0,
+    .param .u64 .ptr .global .align 1 _kernel_param_1,
+    .param .u32 _kernel_param_2,
+    .param .u64 .ptr .global .align 1 _kernel_param_3
+)
+.reqntid 128
+{
+    .reg .pred %p<4>;
+    .reg .b32 %r<10>;
+    .reg .b32 %f<3>;
+    .reg .b64 %rd<6>;
+    .loc 1 180 0
+$L__func_begin0:
+    .loc 1 180 0
+
+// %bb.0:
+    ld.param.u64 %rd3, [_kernel_param_0];
+    ld.param.u64 %rd4, [_kernel_param_1];
+$L__tmp0:
+    .loc 1 181 28
+    mov.u32 %r3, %ctaid.x;
+    .loc 1 181 33
+    shl.b32 %r4, %r3, 5;
+    ld.param.u32 %r5, [_kernel_param_2];
+    .loc 1 181 59
+    mov.u32 %r6, %tid.x;
+    and.b32 %r7, %r6, 31;
+    .loc 1 181 46
+    or.b32 %r8, %r4, %r7;
+    .loc 1 182 46
+    setp.lt.s32 %p1, %r8, %r5;
+    .loc 1 182 22
+    mul.wide.s32 %rd5, %r8, 4;
+    add.s64 %rd1, %rd4, %rd5;
+    .loc 1 182 16
+    // begin inline asm
+    mov.u32 %r1, 0x0;
+    @%p1 ld.global.b32 { %r1 }, [ %rd1 + 0 ];
+    // end inline asm
+    mov.b32 %f1, %r1;
+    .loc 1 183 12
+    mul.f32 %f2, %f1, 0f41200000;
+    .loc 1 184 19
+    add.s64 %rd2, %rd3, %rd5;
+    .loc 1 184 28
+    and.b32 %r9, %r6, 96;
+    setp.eq.s32 %p3, %r9, 0;
+    mov.b32 %r2, %f2;
+    and.pred %p2, %p3, %p1;
+    // begin inline asm
+    @%p2 st.global.b32 [ %rd2 + 0 ], { %r2 };
+    // end inline asm
+    .loc 1 184 4
+    ret;
+$L__tmp1:
+$L__func_end0:
+    // -- End function
+}
+    """
+    temp_file = pathlib.Path(f"/tmp/test_override_{uuid.uuid4()}.ptx")
+    temp_file.write_text(ir_src)
+
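+    # Note the kernel below already computes x * 10, and the PTX does the
+    # same (mul.f32 with 0f41200000, i.e. 10.0), so this test exercises the
+    # override path rather than changing the kernel's result.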
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        x = x * 10
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    _kernel[grid](dst, src, N=N)
+
+    # The kernel ran from the overridden PTX, which also computes x * 10
+    torch.testing.assert_close(src * 10, dst)
+
+
def test_exceed_tmem(device):
    if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10:
        pytest.skip("Test requires tensor memory.")