Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions python/test/regression/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import pytest
import tempfile


def pytest_addoption(parser):
Expand All @@ -14,9 +12,7 @@ def device(request):

@pytest.fixture
def fresh_triton_cache():
with tempfile.TemporaryDirectory() as tmpdir:
try:
os.environ["TRITON_CACHE_DIR"] = tmpdir
yield tmpdir
finally:
os.environ.pop("TRITON_CACHE_DIR", None)
from triton import knobs
with knobs.compilation.scope():
knobs.compilation.always_compile = True
yield
24 changes: 18 additions & 6 deletions python/test/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import sys
import pathlib
import pytest
import tempfile
from typing import Optional, Set
import contextlib


def pytest_configure(config):
Expand Down Expand Up @@ -69,11 +69,23 @@ def device(request):

@pytest.fixture
def fresh_triton_cache():
with tempfile.TemporaryDirectory() as tmpdir:
from triton import knobs
with knobs.cache.scope():
knobs.cache.dir = tmpdir
yield tmpdir
from triton import knobs
with knobs.compilation.scope():
knobs.compilation.always_compile = True
yield


@pytest.fixture
def fresh_triton_cache_scope():
from triton import knobs

@contextlib.contextmanager
def fresh_cache():
with knobs.compilation.scope():
knobs.compilation.always_compile = True
yield

yield fresh_cache


def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
Expand Down
2 changes: 0 additions & 2 deletions python/test/unit/runtime/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import importlib.util
import itertools
import os
import shutil
import pathlib

import pytest
Expand Down Expand Up @@ -495,7 +494,6 @@ def cache_hook(*args, **kwargs):
assert specialization_data is not None

# clear the cache
shutil.rmtree(fresh_triton_cache)
kernel_add.device_caches[device][0].clear()

# preload the kernel
Expand Down
7 changes: 4 additions & 3 deletions python/test/unit/runtime/test_compilation_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from triton.knobs import CompileTimes
from triton.compiler.compiler import ASTSource, IRSource

from typing import Any, Union
from typing import Any, Union, Callable

import torch

Expand All @@ -17,7 +17,7 @@ def cumsum_kernel(ptr):
tl.store(block, tl.cumsum(x, 0))


def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache: str) -> None:
def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache_scope: Callable) -> None:
captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], dict[str, Any], CompileTimes, bool], None] = None

def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str], metadata_group: dict[str, Any],
Expand All @@ -29,7 +29,8 @@ def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str],
fresh_knobs_except_libraries.compilation.listener = compile_listener

x = torch.randn(4, device=device)
cumsum_kernel[(1, )](x)
with fresh_triton_cache_scope():
cumsum_kernel[(1, )](x)

assert captured is not None

Expand Down
2 changes: 0 additions & 2 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import multiprocessing
import shutil

import triton
import triton.language as tl
Expand Down Expand Up @@ -87,7 +86,6 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
compile_empty_kernel_with_gc()

# stage 2.p
shutil.rmtree(fresh_triton_cache)
mp_ctx = multiprocessing.get_context(start_method)
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc)

Expand Down
12 changes: 4 additions & 8 deletions third_party/intel/python/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import pytest
import tempfile


def pytest_addoption(parser):
Expand All @@ -14,9 +12,7 @@ def device(request):

@pytest.fixture
def fresh_triton_cache():
with tempfile.TemporaryDirectory() as tmpdir:
try:
os.environ["TRITON_CACHE_DIR"] = tmpdir
yield tmpdir
finally:
os.environ.pop("TRITON_CACHE_DIR", None)
from triton import knobs
with knobs.compilation.scope():
knobs.compilation.always_compile = True
yield
Loading