@@ -164,6 +164,8 @@ def inc_counter(*args, **kwargs):
164164 for i in range (10 ):
165165 kernel [(1 , )](x , 1 , BLOCK = 1024 )
166166 assert counter == 1
167+ device = getattr (torch , device ).current_device ()
168+ kernel .device_caches [device ][0 ].clear ()
167169
168170
169171@pytest .mark .parametrize ('mode' , ['enable' , 'disable' , 'disable_on_alignment' ])
@@ -181,6 +183,10 @@ def inc_counter(*args, **kwargs):
181183 for i in [1 , 2 , 4 , 8 , 16 , 32 ]:
182184 function [(1 , )](x , i , BLOCK = 512 )
183185 assert counter == target
186+ device = getattr (torch , device ).current_device ()
187+ kernel .device_caches [device ][0 ].clear ()
188+ kernel_nospec .device_caches [device ][0 ].clear ()
189+ kernel_nospec_on_alignment .device_caches [device ][0 ].clear ()
184190
185191
186192def test_annotation (device ):
@@ -489,7 +495,7 @@ def cache_hook(*args, **kwargs):
489495 assert specialization_data is not None
490496
491497 # clear the cache
492- shutil .rmtree (fresh_triton_cache , ignore_errors = True )
498+ shutil .rmtree (fresh_triton_cache )
493499 kernel_add .device_caches [device ][0 ].clear ()
494500
495501 # preload the kernel
0 commit comments