From 9d89a0ab267add01f8c275474631bbb539cc1833 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Fri, 1 Nov 2024 10:32:52 +0000 Subject: [PATCH] Use pytest's tmp_path fixture instead of tempfile.NamedTemporaryFile Signed-off-by: Anatoly Myachev --- python/test/unit/language/test_core.py | 67 +++--- python/test/unit/runtime/test_cache.py | 21 +- third_party/proton/test/test_api.py | 230 +++++++++--------- third_party/proton/test/test_cmd.py | 32 +-- third_party/proton/test/test_lib.py | 40 ++- third_party/proton/test/test_profile.py | 308 ++++++++++++------------ 6 files changed, 350 insertions(+), 348 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 77f09e6fa1..58709d5416 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5,7 +5,6 @@ from typing import Optional import math import textwrap -import tempfile import numpy as np import pytest @@ -2589,7 +2588,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. 
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) -def test_scan_layouts(M, N, src_layout, axis, device): +def test_scan_layouts(M, N, src_layout, axis, device, tmp_path): ir = f""" #blocked = {src_layout} @@ -2622,10 +2621,10 @@ def test_scan_layouts(M, N, src_layout, axis, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_scan_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + rs = RandomState(17) x = rs.randint(-100, 100, (M, N)).astype('int32') @@ -2662,7 +2661,7 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) @pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) -def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device, tmp_path): if isinstance(src_layout, (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") @@ -2756,10 +2755,9 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_reduce_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = numpy_random((M, N), 
dtype_str=dtype_str, rs=rs, low=0, high=10) @@ -2789,7 +2787,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce @pytest.mark.parametrize("M", [32, 64, 128, 256]) @pytest.mark.parametrize("src_layout", layouts) -def test_store_op(M, src_layout, device): +def test_store_op(M, src_layout, device, tmp_path): ir = f""" #src = {src_layout} @@ -2810,10 +2808,9 @@ def test_store_op(M, src_layout, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - store_kernel = triton.compile(f.name) + temp_file = tmp_path / "test_store_op.ttgir" + temp_file.write_text(ir) + store_kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, 1)).astype('float32') @@ -2840,7 +2837,7 @@ def test_store_op(M, src_layout, device): @pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) @pytest.mark.parametrize("src_dim", [0, 1]) @pytest.mark.parametrize("dst_dim", [0, 1]) -def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path): ir = f""" #dst = {dst_layout} @@ -2860,10 +2857,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert1d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, )).astype('int32') @@ -2901,7 +2897,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("op", ["sum", "max"]) @pytest.mark.parametrize("first_axis", [0, 1]) -def test_chain_reduce(M, N, src_layout, op, device, first_axis): +def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path): op_str = "" if op == "sum": @@ 
-2942,10 +2938,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_chain_reduce.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') @@ -5260,7 +5255,7 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path): if str(src_layout) == str(dst_layout): pytest.xfail("Do not convert same layout") if is_hip() or is_xpu(): @@ -5329,10 +5324,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x, device=device) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert2d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5385,7 +5380,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("mma_pair", mma_pairs) -def test_convertmma2mma(M, N, mma_pair, dtype, device): +def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path): if is_hip() or is_xpu(): pytest.xfail("test_mma2mma is not supported in HIP/XPU") @@ -5442,10 +5437,10 @@ def 
do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convertmma2mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a0084e0be9..9fca21af87 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,7 +1,6 @@ import importlib.util import itertools import shutil -import tempfile import pytest import torch @@ -128,17 +127,15 @@ def test_combine_fn_change(): seen_keys.add(key) -def write_and_load_module(code, num_extra_lines): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: - f.write(('# extra line\n' * num_extra_lines) + code) - f.flush() - spec = importlib.util.spec_from_file_location("module.name", f.name) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) +def write_and_load_module(temp_file, code, num_extra_lines): + temp_file.write_text(('# extra line\n' * num_extra_lines) + code) + spec = importlib.util.spec_from_file_location("module.name", str(temp_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) return module -def test_changed_line_numbers_invalidate_cache(): +def test_changed_line_numbers_invalidate_cache(tmp_path): from textwrap import dedent code = dedent(""" import triton @@ -146,10 +143,12 @@ def test_changed_line_numbers_invalidate_cache(): def test_kernel(i): i = i + 1 """) - orig_mod = write_and_load_module(code, 0) + temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py" + orig_mod = write_and_load_module(temp_file0, code, 0) orig_cache_key = orig_mod.test_kernel.cache_key - 
updated_mod = write_and_load_module(code, 1) + temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py" + updated_mod = write_and_load_module(temp_file1, code, 1) updated_cache_key = updated_mod.test_kernel.cache_key assert orig_cache_key != updated_cache_key diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py index 713572c4fc..3244d3b701 100644 --- a/third_party/proton/test/test_api.py +++ b/third_party/proton/test/test_api.py @@ -1,23 +1,24 @@ import json import triton.profiler as proton -import tempfile import pathlib -def test_profile(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id0 = proton.start(f.name.split(".")[0]) - proton.activate() - proton.deactivate() - proton.finalize() - assert session_id0 == 0 +def test_profile(tmp_path): + temp_file0 = tmp_path / "test_profile0.hatchet" + session_id0 = proton.start(str(temp_file0)) + proton.activate() + proton.deactivate() + proton.finalize() + assert session_id0 == 0 + assert temp_file0.exists() - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id1 = proton.start(f.name.split(".")[0]) - proton.activate(session_id1) - proton.deactivate(session_id1) - proton.finalize(session_id1) - assert session_id1 == session_id0 + 1 + temp_file1 = tmp_path / "test_profile1.hatchet" + session_id1 = proton.start(str(temp_file1)) + proton.activate(session_id1) + proton.deactivate(session_id1) + proton.finalize(session_id1) + assert session_id1 == session_id0 + 1 + assert temp_file1.exists() session_id2 = proton.start("test") proton.activate(session_id2) @@ -28,19 +29,16 @@ def test_profile(): pathlib.Path("test.hatchet").unlink() -def test_profile_decorator(): - f = tempfile.NamedTemporaryFile(delete=True) - name = f.name.split(".")[0] +def test_profile_decorator(tmp_path): + temp_file = tmp_path / "test_profile_decorator.hatchet" - @proton.profile(name=name) + @proton.profile(name=str(temp_file)) def foo0(a, 
b): return a + b foo0(1, 2) proton.finalize() - assert pathlib.Path(f.name).exists() - - f.close() + assert temp_file.exists() @proton.profile def foo1(a, b): @@ -48,126 +46,130 @@ def foo1(a, b): foo1(1, 2) proton.finalize() - assert pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet").exists() + default_file = pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet") + assert default_file.exists() + default_file.unlink() -def test_scope(): +def test_scope(tmp_path): # Scope can be annotated even when profiling is off with proton.scope("test"): pass - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test"): - pass + temp_file = tmp_path / "test_scope.hatchet" + proton.start(str(temp_file)) + with proton.scope("test"): + pass - @proton.scope("test") - def foo(): - pass + @proton.scope("test") + def foo(): + pass - foo() + foo() - proton.enter_scope("test") - proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.enter_scope("test") + proton.exit_scope() + proton.finalize() + assert temp_file.exists() -def test_hook(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id0 = proton.start(f.name.split(".")[0], hook="triton") - proton.activate(session_id0) - proton.deactivate(session_id0) - proton.finalize(None) - assert pathlib.Path(f.name).exists() +def test_hook(tmp_path): + temp_file = tmp_path / "test_hook.hatchet" + session_id0 = proton.start(str(temp_file), hook="triton") + proton.activate(session_id0) + proton.deactivate(session_id0) + proton.finalize(None) + assert temp_file.exists() -def test_scope_metrics(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0]) - # Test different scope creation methods - with proton.scope("test0", {"a": 1.0}): - pass +def test_scope_metrics(tmp_path): + temp_file = tmp_path / "test_scope_metrics.hatchet" + 
session_id = proton.start(str(temp_file)) + # Test different scope creation methods + with proton.scope("test0", {"a": 1.0}): + pass - @proton.scope("test1", {"a": 1.0}) - def foo(): - pass + @proton.scope("test1", {"a": 1.0}) + def foo(): + pass - foo() + foo() - # After deactivation, the metrics should be ignored - proton.deactivate(session_id) - proton.enter_scope("test2", metrics={"a": 1.0}) - proton.exit_scope() + # After deactivation, the metrics should be ignored + proton.deactivate(session_id) + proton.enter_scope("test2", metrics={"a": 1.0}) + proton.exit_scope() - # Metrics should be recorded again after reactivation - proton.activate(session_id) - proton.enter_scope("test3", metrics={"a": 1.0}) - proton.exit_scope() + # Metrics should be recorded again after reactivation + proton.activate(session_id) + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() - proton.enter_scope("test3", metrics={"a": 1.0}) - proton.exit_scope() + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.finalize() + assert temp_file.exists() + with open(str(temp_file)) as f: data = json.load(f) - assert len(data[0]["children"]) == 3 - for child in data[0]["children"]: - if child["frame"]["name"] == "test3": - assert child["metrics"]["a"] == 2.0 - - -def test_scope_properties(): - with open("test.hatchet", "w+") as f: - proton.start(f.name.split(".")[0]) - # Test different scope creation methods - # Different from metrics, properties could be str - with proton.scope("test0", properties={"a": "1"}): - pass + assert len(data[0]["children"]) == 3 + for child in data[0]["children"]: + if child["frame"]["name"] == "test3": + assert child["metrics"]["a"] == 2.0 + + +def test_scope_properties(tmp_path): + temp_file = tmp_path / "test.hatchet" + proton.start(str(temp_file)) + # Test different scope creation methods + # Different from metrics, properties could be str + with 
proton.scope("test0", properties={"a": "1"}): + pass - @proton.scope("test1", properties={"a": "1"}) - def foo(): - pass + @proton.scope("test1", properties={"a": "1"}) + def foo(): + pass - foo() + foo() - # Properties do not aggregate - proton.enter_scope("test2", properties={"a": 1.0}) - proton.exit_scope() + # Properties do not aggregate + proton.enter_scope("test2", properties={"a": 1.0}) + proton.exit_scope() - proton.enter_scope("test2", properties={"a": 1.0}) - proton.exit_scope() + proton.enter_scope("test2", properties={"a": 1.0}) + proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.finalize() + assert temp_file.exists() + with open(str(temp_file)) as f: data = json.load(f) - for child in data[0]["children"]: - if child["frame"]["name"] == "test2": - assert child["metrics"]["a"] == 1.0 - elif child["frame"]["name"] == "test0": - assert child["metrics"]["a"] == "1" + for child in data[0]["children"]: + if child["frame"]["name"] == "test2": + assert child["metrics"]["a"] == 1.0 + elif child["frame"]["name"] == "test0": + assert child["metrics"]["a"] == "1" -def test_throw(): +def test_throw(tmp_path): # Catch an exception thrown by c++ session_id = 100 - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - activate_error = "" - try: - session_id = proton.start(f.name.split(".")[0]) - proton.activate(session_id + 1) - except Exception as e: - activate_error = str(e) - finally: - proton.finalize() - assert "Session has not been initialized: " + str(session_id + 1) in activate_error - - deactivate_error = "" - try: - session_id = proton.start(f.name.split(".")[0]) - proton.deactivate(session_id + 1) - except Exception as e: - deactivate_error = str(e) - finally: - proton.finalize() - assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error + temp_file = tmp_path / "test_throw.hatchet" + activate_error = "" + try: + session_id = proton.start(str(temp_file)) + 
proton.activate(session_id + 1) + except Exception as e: + activate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in activate_error + + deactivate_error = "" + try: + session_id = proton.start(str(temp_file)) + proton.deactivate(session_id + 1) + except Exception as e: + deactivate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index fa3331c024..62ca2ca155 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -1,6 +1,5 @@ import pytest import subprocess -import tempfile import json @@ -11,21 +10,22 @@ def test_help(): @pytest.mark.parametrize("mode", ["script", "python", "pytest"]) -def test_exec(mode): +def test_exec(mode, tmp_path): file_path = __file__ helper_file = file_path.replace("test_cmd.py", "helper.py") - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - name = f.name.split(".")[0] - if mode == "script": - ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) - elif mode == "python": - ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], - stdout=subprocess.DEVNULL) - elif mode == "pytest": - ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], - stdout=subprocess.DEVNULL) - assert ret == 0 + temp_file = tmp_path / "test_exec.hatchet" + name = str(temp_file) + if mode == "script": + ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + elif mode == "python": + ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) + elif mode == "pytest": + ret = subprocess.check_call(["proton", "-n", name, 
"pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) + assert ret == 0 + with open(name) as f: data = json.load(f, ) - kernels = data[0]["children"] - assert len(kernels) == 2 - assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" + kernels = data[0]["children"] + assert len(kernels) == 2 + assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" diff --git a/third_party/proton/test/test_lib.py b/third_party/proton/test/test_lib.py index 0380268c04..9931480e38 100644 --- a/third_party/proton/test/test_lib.py +++ b/third_party/proton/test/test_lib.py @@ -1,6 +1,4 @@ import triton._C.libproton.proton as libproton -import tempfile -import pathlib from triton.profiler.profile import _select_backend @@ -25,22 +23,22 @@ def test_op(): libproton.exit_op(id0, "zero") -def test_session(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = libproton.start(f.name.split(".")[0], "shadow", "tree", _select_backend()) - libproton.deactivate(session_id) - libproton.activate(session_id) - libproton.finalize(session_id, "hatchet") - libproton.finalize_all("hatchet") - assert pathlib.Path(f.name).exists() - - -def test_add_metrics(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - libproton.start(f.name.split(".")[0], "shadow", "tree", _select_backend()) - id1 = libproton.record_scope() - libproton.enter_scope(id1, "one") - libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) - libproton.exit_scope(id1, "one") - libproton.finalize_all("hatchet") - assert pathlib.Path(f.name).exists() +def test_session(tmp_path): + temp_file = tmp_path / "test_session.hatchet" + session_id = libproton.start(str(temp_file), "shadow", "tree", _select_backend()) + libproton.deactivate(session_id) + libproton.activate(session_id) + libproton.finalize(session_id, "hatchet") + libproton.finalize_all("hatchet") + assert temp_file.exists() + + +def 
test_add_metrics(tmp_path): + temp_file = tmp_path / "test_add_metrics.hatchet" + libproton.start(str(temp_file), "shadow", "tree", _select_backend()) + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) + libproton.exit_scope(id1, "one") + libproton.finalize_all("hatchet") + assert temp_file.exists() diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 13cb9bd99c..2ef748d886 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -1,7 +1,6 @@ import torch import triton import triton.profiler as proton -import tempfile import json import pytest from typing import NamedTuple @@ -14,30 +13,31 @@ def is_hip(): @pytest.mark.parametrize("context", ["shadow", "python"]) -def test_torch(context): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], context=context) - proton.enter_scope("test") - torch.ones((2, 2), device="cuda") - proton.exit_scope() - proton.finalize() +def test_torch(context, tmp_path): + temp_file = tmp_path / "test_torch.hatchet" + proton.start(str(temp_file), context=context) + proton.enter_scope("test") + torch.ones((2, 2), device="cuda") + proton.exit_scope() + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - if context == "shadow": - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test" - assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 - elif context == "python": - assert len(data[0]["children"]) == 1 - # The last frame is the torch kernel - prev_frame = data - curr_frame = data[0]["children"] - while len(curr_frame) > 0: - prev_frame = curr_frame - curr_frame = curr_frame[0]["children"] - assert "elementwise_kernel" in prev_frame[0]["frame"]["name"] - - -def test_triton(): + if context == "shadow": + assert len(data[0]["children"]) == 1 + 
assert data[0]["children"][0]["frame"]["name"] == "test" + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + elif context == "python": + assert len(data[0]["children"]) == 1 + # The last frame is the torch kernel + prev_frame = data + curr_frame = data[0]["children"] + while len(curr_frame) > 0: + prev_frame = curr_frame + curr_frame = curr_frame[0]["children"] + assert "elementwise_kernel" in prev_frame[0]["frame"]["name"] + + +def test_triton(tmp_path): @triton.jit def foo(x, y): @@ -45,23 +45,24 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test0"): - with proton.scope("test1"): - foo[(1, )](x, y) - with proton.scope("test2"): + temp_file = tmp_path / "test_triton.hatchet" + proton.start(str(temp_file)) + with proton.scope("test0"): + with proton.scope("test1"): foo[(1, )](x, y) - proton.finalize() + with proton.scope("test2"): + foo[(1, )](x, y) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - assert len(data[0]["children"]) == 2 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert len(data[0]["children"][0]["children"]) == 1 - assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" - assert data[0]["children"][1]["frame"]["name"] == "test2" + assert len(data[0]["children"]) == 2 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert len(data[0]["children"][0]["children"]) == 1 + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" + assert data[0]["children"][1]["frame"]["name"] == "test2" -def test_cudagraph(): +def test_cudagraph(tmp_path): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) @@ -75,46 +76,47 @@ def fn(): c = a + b foo[(1, )](a, b, c) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], 
context="shadow") + temp_file = tmp_path / "test_cudagraph.hatchet" + proton.start(str(temp_file), context="shadow") - # warmup - # four kernels - fn() + # warmup + # four kernels + fn() - # no kernels - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(10): - fn() + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn() - proton.enter_scope("test") - g.replay() - g.reset() - torch.cuda.synchronize() - proton.exit_scope() - proton.finalize() + proton.enter_scope("test") + g.replay() + g.reset() + torch.cuda.synchronize() + proton.exit_scope() + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - # CUDA/HIP graph may also invoke additional kernels to reset outputs - # {torch.ones, add, foo, test} - assert len(data[0]["children"]) >= 4 - # find the test frame - test_frame = None - for child in data[0]["children"]: - if child["frame"]["name"] == "test": - test_frame = child - break - assert test_frame is not None - # {torch.ones, add, foo} - if is_hip(): - assert len(test_frame["children"]) >= 2 - else: - assert len(test_frame["children"]) >= 3 - assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 - - -def test_metrics(): + # CUDA/HIP graph may also invoke additional kernels to reset outputs + # {torch.ones, add, foo, test} + assert len(data[0]["children"]) >= 4 + # find the test frame + test_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test": + test_frame = child + break + assert test_frame is not None + # {torch.ones, add, foo} + if is_hip(): + assert len(test_frame["children"]) >= 2 + else: + assert len(test_frame["children"]) >= 3 + assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 + + +def test_metrics(tmp_path): @triton.jit def foo(x, y): @@ -122,18 +124,19 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - 
proton.start(f.name.split(".")[0]) - with proton.scope("test0", {"foo": 1.0}): - foo[(1, )](x, y) - proton.finalize() + temp_file = tmp_path / "test_metrics.hatchet" + proton.start(str(temp_file)) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert data[0]["children"][0]["metrics"]["foo"] == 1.0 + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["metrics"]["foo"] == 1.0 -def test_metrics_ignore(): +def test_metrics_ignore(tmp_path): @triton.jit def foo(x, y): @@ -141,36 +144,38 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0]) - proton.deactivate(session_id) - with proton.scope("test0", {"foo": 1.0}): - foo[(1, )](x, y) - proton.activate(session_id) - proton.finalize() + temp_file = tmp_path / "test_metrics_ignore.hatchet" + session_id = proton.start(str(temp_file)) + proton.deactivate(session_id) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.activate(session_id) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - assert len(data[0]["children"]) == 0 - - -def test_scope_backward(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("ones1"): - a = torch.ones((100, 100), device="cuda", requires_grad=True) - with proton.scope("plus"): - a2 = a * a * a - with proton.scope("ones2"): - loss = torch.ones_like(a2) - - # Backward triggers two kernels in a single scope - with proton.scope("backward"): - a2.backward(loss) - proton.finalize() + assert len(data[0]["children"]) == 0 + + +def test_scope_backward(tmp_path): + temp_file 
= tmp_path / "test_scope_backward.hatchet" + proton.start(str(temp_file)) + with proton.scope("ones1"): + a = torch.ones((100, 100), device="cuda", requires_grad=True) + with proton.scope("plus"): + a2 = a * a * a + with proton.scope("ones2"): + loss = torch.ones_like(a2) + + # Backward triggers two kernels in a single scope + with proton.scope("backward"): + a2.backward(loss) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - assert len(data[0]["children"]) == 4 + assert len(data[0]["children"]) == 4 -def test_hook(): +def test_hook(tmp_path): def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): # get arg's element size @@ -187,20 +192,21 @@ def foo(x, size: tl.constexpr, y): x = torch.tensor([2], device="cuda", dtype=torch.float32) y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], hook="triton") - with proton.scope("test0"): - foo[(1, )](x, 1, y, num_warps=4) - proton.finalize() + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file), hook="triton") + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" - assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 - assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" + assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 -def test_pcsampling(): +def test_pcsampling(tmp_path): if is_hip(): 
pytest.skip("HIP backend does not support pc sampling") @@ -214,37 +220,39 @@ def foo(x, y, size: tl.constexpr): for _ in range(1000): tl.store(y + offs, tl.load(x + offs)) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], hook="triton", backend="cupti_pcsampling") - with proton.scope("init"): - x = torch.ones((1024, ), device="cuda", dtype=torch.float32) - y = torch.zeros_like(x) - with proton.scope("test"): - foo[(1, )](x, y, x.size()[0], num_warps=4) - proton.finalize() + temp_file = tmp_path / "test_pcsampling.hatchet" + proton.start(str(temp_file), hook="triton", backend="cupti_pcsampling") + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - init_frame = data[0]["children"][0] - test_frame = data[0]["children"][1] - # With line mapping - assert "foo" in test_frame["children"][0]["frame"]["name"] - assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 - assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] - # Without line mapping - assert "elementwise" in init_frame["children"][0]["frame"]["name"] - assert init_frame["children"][0]["metrics"]["num_samples"] > 0 - - -def test_deactivate(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0], hook="triton") - proton.deactivate(session_id) - torch.randn((10, 10), device="cuda") - proton.activate(session_id) - torch.zeros((10, 10), device="cuda") - proton.deactivate(session_id) - proton.finalize() + init_frame = data[0]["children"][0] + test_frame = data[0]["children"][1] + # With line mapping + assert "foo" in test_frame["children"][0]["frame"]["name"] + assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 + 
assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] + # Without line mapping + assert "elementwise" in init_frame["children"][0]["frame"]["name"] + assert init_frame["children"][0]["metrics"]["num_samples"] > 0 + + +def test_deactivate(tmp_path): + temp_file = tmp_path / "test_deactivate.hatchet" + session_id = proton.start(str(temp_file), hook="triton") + proton.deactivate(session_id) + torch.randn((10, 10), device="cuda") + proton.activate(session_id) + torch.zeros((10, 10), device="cuda") + proton.deactivate(session_id) + proton.finalize() + with open(str(temp_file)) as f: data = json.load(f) - # Root shouldn't have device id - assert "device_id" not in data[0]["metrics"] - assert len(data[0]["children"]) == 1 - assert "device_id" in data[0]["children"][0]["metrics"] + # Root shouldn't have device id + assert "device_id" not in data[0]["metrics"] + assert len(data[0]["children"]) == 1 + assert "device_id" in data[0]["children"][0]["metrics"]