|
1 | 1 | import gc |
2 | | -# import importlib |
3 | | -# import os |
4 | | -# import sys |
5 | | -# import tempfile |
6 | | -# import textwrap |
7 | | -# import time |
8 | 2 | import tracemalloc |
| 3 | +import pytest |
| 4 | +import pathlib |
| 5 | +import os |
9 | 6 |
|
10 | 7 | import torch |
11 | | - |
12 | 8 | import triton |
13 | 9 | import triton.language as tl |
14 | | - |
15 | | -# from typing import Tuple |
| 10 | +from triton._internal_testing import is_cuda, is_hip |
16 | 11 |
|
17 | 12 |
|
18 | 13 | def test_metadata() -> None: |
@@ -149,3 +144,63 @@ def kernel(x): |
149 | 144 | triton.knobs.runtime.kernel_load_end_hook.remove(hook_end0) |
150 | 145 | triton.knobs.runtime.kernel_load_start_hook.remove(hook_start1) |
151 | 146 | triton.knobs.runtime.kernel_load_end_hook.remove(hook_end1) |
| 147 | + |
| 148 | + |
| 149 | +@pytest.mark.parametrize("options", [ |
| 150 | + {"num_warps": 1}, |
| 151 | + {"enable_fp_fusion": False}, |
| 152 | + {"extern_libs": {}}, |
| 153 | +]) |
| 154 | +def test_launch_with_options(options) -> None: |
| 155 | + if "extern_libs" in options: |
| 156 | + # copied from tutorials/07-extern-functions.py |
| 157 | + current_dir = pathlib.Path(os.path.dirname(os.path.abspath(__file__))) |
| 158 | + if is_cuda(): |
| 159 | + libdir = current_dir.parent.parent.parent.parent / 'third_party/nvidia/backend/lib' |
| 160 | + options["extern_libs"] = {"libdevice": str(libdir / 'libdevice.10.bc')} |
| 161 | + elif is_hip(): |
| 162 | + libdir = current_dir.parent.parent.parent.parent / 'third_party/amd/backend/lib' |
| 163 | + options["extern_libs"] = {"ocml": str(libdir / 'ocml.bc'), "ockl": str(libdir / 'ockl.bc')} |
| 164 | + |
| 165 | + compile_info = {} |
| 166 | + counter = 0 |
| 167 | + |
| 168 | + def compile_info_hook(key, repr, fn, compile, is_manual_warmup, already_compiled): |
| 169 | + nonlocal compile_info |
| 170 | + compile_info = compile |
| 171 | + |
| 172 | + def cache_hook(*args, **kwargs): |
| 173 | + nonlocal counter |
| 174 | + counter += 1 |
| 175 | + |
| 176 | + @triton.jit |
| 177 | + def kernel(x): |
| 178 | + pass |
| 179 | + |
| 180 | + triton.knobs.runtime.jit_post_compile_hook = compile_info_hook |
| 181 | + triton.knobs.runtime.jit_cache_hook = cache_hook |
| 182 | + |
| 183 | + # run first without options |
| 184 | + kernel[(1, 1, 1)](6) |
| 185 | + assert counter == 1 |
| 186 | + |
| 187 | + # run with options, should lead to new compilation |
| 188 | + kernel[(1, 1, 1)](6, **options) |
| 189 | + assert counter == 2 |
| 190 | + |
| 191 | + # run a second time for testing kernel-cache look-up |
| 192 | + kernel[(1, 1, 1)](6, **options) |
| 193 | + assert counter == 2 |
| 194 | + |
| 195 | + # check the options are passed on to compile_info correctly |
| 196 | + option_key, option_val = next(iter(options.items())) |
| 197 | + if option_key == "extern_libs": |
| 198 | + # HIPOptions overwrite the extern_libs option, so we skip the test |
| 199 | + # passing and specializing options still is tested |
| 200 | + if not is_hip(): |
| 201 | + assert compile_info[option_key] == tuple(option_val.items()) |
| 202 | + else: |
| 203 | + assert compile_info[option_key] == option_val |
| 204 | + |
| 205 | + triton.knobs.runtime.jit_post_compile_hook = None |
| 206 | + triton.knobs.runtime.jit_cache_hook = None |
0 commit comments