@@ -53,32 +53,36 @@ def cache_hook(*args, **kwargs):
53
53
fn_name = kwargs ["fn" ].name
54
54
module_name = kwargs ["fn" ].module
55
55
56
- triton .knobs .runtime .jit_cache_hook = cache_hook
57
- o = torch .empty ((1 , ), dtype = torch .float32 , device = device )
58
- k = specialized_kernel [(1 , )](o , )
59
- hash = k .hash
60
- assert o .item () == 1.0
61
- assert module_name == "tests.test_specialize"
62
- assert fn_name == "cacheable_kernel"
63
-
64
- compile_count = 0
65
-
66
- def count_hook (* args , ** kwargs ):
67
- nonlocal compile_count
68
- compile_count += 1
69
-
70
- triton .knobs .runtime .jit_cache_hook = count_hook
71
- # clear the cache
72
- specialized_kernel .device_caches .clear ()
73
-
74
- # retrieve the kernel from name and preload it.
75
- fn = retrieve_fn (module_name , fn_name )
76
- assert fn == specialized_kernel
77
- preload = fn .preload (specialization_data )
78
- assert compile_count == 1
79
- assert preload .hash == hash
80
-
81
- # verify that we hit the cache.
82
- compile_count = 0
83
- specialized_kernel [(1 , )](o , )
84
- assert compile_count == 0
56
+ prev_hook = triton .knobs .runtime .jit_cache_hook
57
+ try :
58
+ triton .knobs .runtime .jit_cache_hook = cache_hook
59
+ o = torch .empty ((1 , ), dtype = torch .float32 , device = device )
60
+ k = specialized_kernel [(1 , )](o , )
61
+ hash = k .hash
62
+ assert o .item () == 1.0
63
+ assert module_name == "tests.test_specialize"
64
+ assert fn_name == "cacheable_kernel"
65
+
66
+ compile_count = 0
67
+
68
+ def count_hook (* args , ** kwargs ):
69
+ nonlocal compile_count
70
+ compile_count += 1
71
+
72
+ triton .knobs .runtime .jit_cache_hook = count_hook
73
+ # clear the cache
74
+ specialized_kernel .device_caches .clear ()
75
+
76
+ # retrieve the kernel from name and preload it.
77
+ fn = retrieve_fn (module_name , fn_name )
78
+ assert fn == specialized_kernel
79
+ preload = fn .preload (specialization_data )
80
+ assert compile_count == 1
81
+ assert preload .hash == hash
82
+
83
+ # verify that we hit the cache.
84
+ compile_count = 0
85
+ specialized_kernel [(1 , )](o , )
86
+ assert compile_count == 0
87
+ finally :
88
+ triton .knobs .runtime .jit_cache_hook = prev_hook
0 commit comments