|
@@ -1,24 +1,14 @@
 # fmt: off


-import os
 import numpy as np
 import torch
 import pytest
 import triton
 import triton.language as tl

-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300

-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')

 def matching_int(dtype):
     if dtype.primitive_bitwidth == 8:
@@ -283,7 +273,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
             or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip())
-            or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))):
+            or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))):
         # If the dtype should error out in the given device, we assert that and return
         with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
             launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
@@ -334,7 +324,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
     if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")

-    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()):
+    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")

     # dtype : (exponent_bits, mantissa_bits, exponent_bias)
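For context, the deleted helpers are not dropped but moved into Triton's shared test utilities, with is_on_mi300 renamed to is_hip_mi300. Below is a minimal sketch of what triton/_internal_testing.py presumably provides, reconstructed from the definitions removed above; the actual module contains more helpers and may differ in detail.

# Hypothetical reconstruction of the relocated helpers, based solely on
# the definitions deleted in the diff above. Not the authoritative source.
import os

import triton


def is_interpreter():
    # The Triton interpreter backend is selected via an environment variable.
    return os.environ.get('TRITON_INTERPRET', '0') == '1'


def is_cuda():
    # True when the active target is the NVIDIA backend (and not the interpreter).
    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip():
    # True when the active target is the AMD ROCm backend (and not the interpreter).
    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"


def is_hip_mi300():
    # MI300-family targets (gfx940/gfx941/gfx942); renamed from is_on_mi300.
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')

Centralizing these checks means individual test files no longer re-implement backend detection, and the MI300 architecture list only needs updating in one place.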
|