import itertools
import os
import tempfile
from typing import Any, List

import torch
@@ -327,48 +328,63 @@ def test_multiple_dynamic_shapes_cache():
327328
328329
class GemmRunnerComplexTuningConfigs(TunableRunner):
    """Runner used to test that tactic configs of several Python types
    (int, tuple, list) survive tuning and cache (de)serialization."""

    # Test serialization of different types of tactics:
    # plain ints, tuples, and lists must all round-trip through the cache.
    valid_tactic_ids = [-1, 0, 1]
    valid_tile_sizes = [(128, 128), (256, 256)]
    valid_cluster_sizes = [[1, 1, 1], [2, 2, 1]]

    # Tuning inputs are expected to be clipped to at most this many tokens.
    tune_max_num_tokens = 32
332338
333339 def get_valid_tactics (
334340 self ,
335341 inputs : List [FakeTensor ],
336342 profile : OptimizationProfile ,
337343 ** kwargs ,
338- ) -> List [Dict [ str , int ] ]:
344+ ) -> List [Any ]:
339345 # During the tuning process, we verify if the tuning config behaves as expected
340-
341346 assert inputs [0 ].shape [0 ] <= self .tune_max_num_tokens , \
342347 f"Input shape { inputs [0 ].shape [0 ]} is larger than the max num tokens { self .tune_max_num_tokens } "
343348
344349 assert inputs [0 ][- 1 , 0 ] == inputs [0 ].shape [0 ], \
345350 f"Input shape { inputs [0 ].shape [0 ]} is not set through the pre_hook correctly"
346351
347- # The simulated delay is not deterministic, so we need to return specific tactics here
348352 return [{
349- "block_size" : block_size ,
350- "tactic_id" : tactic_id
351- } for tactic_id in self .valid_tactic_ids for block_size in [128 , 256 ]]
353+ "int_tactic_id" : tactic_id ,
354+ "tuple_tile_size" : tile_size ,
355+ "list_cluster_size" : cluster_size ,
356+ } for tactic_id , tile_size , cluster_size in itertools .product (
357+ self .valid_tactic_ids ,
358+ self .valid_tile_sizes ,
359+ self .valid_cluster_sizes ,
360+ )]
352361
353362 def forward (
354363 self ,
355364 / ,
356365 inputs : List [torch .Tensor ],
357366 * ,
358- tactic : dict = {} ,
367+ tactic : Any = - 1 ,
359368 ) -> torch .Tensor :
360369 # Notice that in fallback case tactic is -1
361370 if tactic == - 1 :
362371 # assign default configs for fallback case
363- block_size , tactic_id = 128 , - 1
372+ tactic_id , tile_size , cluster_size = - 1 , ( 128 , 256 ), [ 1 , 1 , 1 ]
364373 else :
365- block_size , tactic_id = tactic ["block_size" ], tactic ["tactic_id" ]
366- assert tactic_id in self .valid_tactic_ids
374+ tactic_id , tile_size , cluster_size = tactic [
375+ "int_tactic_id" ], tactic ["tuple_tile_size" ], tactic [
376+ "list_cluster_size" ]
377+
378+ assert isinstance (tactic_id , int ) and tactic_id in self .valid_tactic_ids
379+ assert isinstance (tile_size , tuple ) and len (tile_size ) == 2 \
380+ and tile_size in self .valid_tile_sizes
381+ assert isinstance (cluster_size , list ) and len (cluster_size ) == 3 \
382+ and cluster_size in self .valid_cluster_sizes
367383 return [gemm_0 , gemm_1 , gemm_fallback ][tactic_id ](* inputs )
368384
369385 @staticmethod
370386 def inputs_pre_hook (inputs : List [torch .Tensor ]):
371- # always set the first element to bo iota in x
387+ # always set the first element to be the number of tokens in x
372388 x , w = inputs
373389 x_hooked = torch .zeros_like (x )
374390 x_hooked [- 1 , 0 ] = x .shape [0 ]
@@ -389,13 +405,29 @@ def test_autotuner_tuning_configs():
389405 # Test if the number of tuning tokens is clipped to 32
390406 tune_max_num_tokens = GemmRunnerComplexTuningConfigs .tune_max_num_tokens ,
391407 inputs_pre_hook = GemmRunnerComplexTuningConfigs .inputs_pre_hook ,
408+ use_cold_l2_cache = True ,
409+ use_cuda_graph = False ,
392410 )
393- with autotune ():
411+ temp_dir = tempfile .TemporaryDirectory ()
412+ with autotune (cache_path = os .path .join (
413+ temp_dir .name , "test_autotuner_tactic_configs.json" )):
394414 tuner = AutoTuner .get ()
395- runner , tactic = tuner .choose_one ("test_autotuner_tactic_configs" ,
396- runners , tuning_config , [x , w ])
415+ runner , best_tactic = tuner .choose_one ("test_autotuner_tactic_configs" ,
416+ runners , tuning_config , [x , w ])
417+
418+ runner_0 ([x , w ], tactic = best_tactic )
419+
420+ # Test if the tactic can be loaded from cache correctly
421+ AutoTuner .get ().profiling_cache .clear ()
422+ AutoTuner .get ().profiling_cache .load_cache (
423+ os .path .join (temp_dir .name , "test_autotuner_tactic_configs.rank0.json" ))
424+
425+ # No further tuning should be performed.
426+ runner , deserialized_tactic = tuner .choose_one (
427+ "test_autotuner_tactic_configs" , runners , tuning_config , [x , w ])
428+ assert best_tactic == deserialized_tactic , "Tactic should be the same after deserialization"
397429
398- runner_0 . forward ( inputs = [x , w ], tactic = tactic )
430+ runner_0 ( [x , w ], tactic = deserialized_tactic )
399431
400432
401433def test_kernel_testing_single_context ():
0 commit comments