@@ -199,7 +199,7 @@ def kernel(X, i: tl.int32):
199199 kernel [(1 , )](x , 8 )
200200 kernel [(1 , )](x , 16 )
201201 kernel [(1 , )](x , 17 )
202- assert len (kernel .cache [device ]) == 3
202+ assert len (kernel .device_caches [device ][ 0 ]) == 3
203203
204204
205205GLOBAL_DEFAULT_ARG = 1
@@ -223,7 +223,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
223223 assert x == torch .ones_like (x )
224224
225225 device = getattr (torch , device ).current_device ()
226- assert len (kernel .cache [device ]) == 1
226+ assert len (kernel .device_caches [device ][ 0 ]) == 1
227227
228228
229229GLOBAL_VAR : tl .constexpr = 1
@@ -416,13 +416,13 @@ def kernel_add(a, b, o, N: tl.constexpr):
416416 32 ,
417417 ]
418418 device = getattr (torch , device ).current_device ()
419- assert len (kernel_add .cache [device ]) == 0
419+ assert len (kernel_add .device_caches [device ][ 0 ]) == 0
420420 kernel_add .warmup (torch .float32 , torch .float32 , torch .float32 , 32 , grid = (1 , ))
421- assert len (kernel_add .cache [device ]) == 1
421+ assert len (kernel_add .device_caches [device ][ 0 ]) == 1
422422 kernel_add .warmup (* args , grid = (1 , ))
423- assert len (kernel_add .cache [device ]) == 1
423+ assert len (kernel_add .device_caches [device ][ 0 ]) == 1
424424 kernel_add .warmup (* args , grid = (1 , ))
425- assert len (kernel_add .cache [device ]) == 1
425+ assert len (kernel_add .device_caches [device ][ 0 ]) == 1
426426
427427
428428def test_jit_debug (device ) -> None :
@@ -433,12 +433,12 @@ def kernel(tmp):
433433
434434 device = getattr (torch , device ).current_device ()
435435 tmp = torch .tensor ([1 ], dtype = torch .int32 , device = device )
436- assert len (kernel .cache [device ]) == 0
436+ assert len (kernel .device_caches [device ][ 0 ]) == 0
437437 kernel [(1 , )](tmp , debug = False )
438- assert len (kernel .cache [device ]) == 1
438+ assert len (kernel .device_caches [device ][ 0 ]) == 1
439439 kernel [(1 , )](tmp , debug = True )
440- assert len (kernel .cache [device ]) == 2
441- bins = list (kernel .cache [device ].values ())
440+ assert len (kernel .device_caches [device ][ 0 ]) == 2
441+ bins = list (kernel .device_caches [device ][ 0 ].values ())
442442 assert bins [0 ].asm ['ttir' ] != bins [1 ].asm ['ttir' ]
443443
444444
@@ -455,18 +455,18 @@ def kernel_add_device(a, b, o, N: tl.constexpr):
455455 add_fn (a , b , o , N )
456456
457457 device = getattr (torch , device ).current_device ()
458- assert len (kernel_add_device .cache [device ]) == 0
458+ assert len (kernel_add_device .device_caches [device ][ 0 ]) == 0
459459 kernel_add_device .warmup (torch .float32 , torch .float32 , torch .float32 , 32 , grid = (1 , ))
460- assert len (kernel_add_device .cache [device ]) == 1
461- bins = list (kernel_add_device .cache [device ].values ())
460+ assert len (kernel_add_device .device_caches [device ][ 0 ]) == 1
461+ bins = list (kernel_add_device .device_caches [device ][ 0 ].values ())
462462 inline_ttir = bins [0 ].asm ['ttir' ]
463463 add_fn .noinline = True
464464 add_fn .hash = None
465465 kernel_add_device .hash = None
466- kernel_add_device .cache [device ].clear ()
466+ kernel_add_device .device_caches [device ][ 0 ].clear ()
467467 kernel_add_device .warmup (torch .float32 , torch .float32 , torch .float32 , 32 , grid = (1 , ))
468- assert len (kernel_add_device .cache [device ]) == 1
469- bins = list (kernel_add_device .cache [device ].values ())
468+ assert len (kernel_add_device .device_caches [device ][ 0 ]) == 1
469+ bins = list (kernel_add_device .device_caches [device ][ 0 ].values ())
470470 noinline_ttir = bins [0 ].asm ['ttir' ]
471471 assert inline_ttir != noinline_ttir
472472
@@ -514,12 +514,12 @@ def cache_hook(*args, **kwargs):
514514
515515 # clear the cache
516516 shutil .rmtree (fresh_triton_cache )
517- kernel_add .cache [device ].clear ()
517+ kernel_add .device_caches [device ][ 0 ].clear ()
518518
519519 # preload the kernel
520520 kernel_preload = kernel_add .preload (specialization_data )
521521 assert kernel_preload .hash == hash
522- assert len (kernel_add .cache [device ]) == 1
522+ assert len (kernel_add .device_caches [device ][ 0 ]) == 1
523523
524524 # we should hit the cache and not compile anything
525525 counter = 0
@@ -532,7 +532,7 @@ def inc_counter(*args, **kwargs):
532532 final_kernel = kernel_add .warmup (torch .float32 , torch .float32 , torch .float32 , 32 , tl .float32 , grid = (1 , ))
533533 JITFunction .cache_hook = None
534534 assert counter == 0
535- assert len (kernel_add .cache [device ]) == 1
535+ assert len (kernel_add .device_caches [device ][ 0 ]) == 1
536536 assert final_kernel .hash == hash
537537
538538 # test that we can't preload a mismatched kernel
@@ -572,7 +572,7 @@ def compiled_hook(*args, **kwargs):
572572 kernel_add .warmup (torch .float32 , torch .float32 , torch .float32 , 32 , tl .float32 , grid = (1 , ))
573573 assert specialization_data is not None and specialization_data_compiled == specialization_data
574574 assert is_warmup is True
575- assert key in kernel_add .cache [getattr (torch , device ).current_device ()]
575+ assert key in kernel_add .device_caches [getattr (torch , device ).current_device ()][ 0 ]
576576
577577
578578@pytest .mark .skipif (reason = "within_2g is a HIP specific optimization" , condition = not is_hip ())
0 commit comments