@@ -215,34 +215,36 @@ def bucket_mapper(x: int, inflate_factor: int) -> int:
215215 "test_runtime_bucket_mapping" )
216216
217217 # Extract the first dimension of the first input shape from each cache key
218- cached_bucket_sizes = sorted ([key [3 ][0 ][0 ] for key in cache_entries .keys ()])
219- expected_buckets = [1 , 2 , 4 , 8 , 16 , 32 ]
220- assert cached_bucket_sizes == expected_buckets , \
221- f"Expect { expected_buckets } , got { cached_bucket_sizes } "
218+ assert len (cache_entries ) == 6 , \
219+ f"Expected 6 cache entries (buckets 1, 2, 4, 8, 16, 32), got { len (cache_entries )} "
222220
223221 # Test runtime mapping: input size should be mapped via map_to_runtime_buckets
224222 # to find the correct tuning bucket
225223 test_cases = [
226- # input_size, expected_mapped_bucket via round_rule
227- (4 , 1 ), # round_rule(4) = 4 // 4 = 1
228- (8 , 2 ), # round_rule(8) = 8 // 4 = 2
229- (16 , 4 ), # round_rule(16) = 16 // 4 = 4
230- (32 , 8 ), # round_rule(32) = 32 // 4 = 8
231- (64 , 16 ), # round_rule(64) = 64 // 4 = 16
232- (128 , 32 ), # round_rule(128) = 128 // 4 = 32
224+ # size 4 maps to bucket 4//4 = 1, tactic 0 (1 <= M // 2)
225+ (4 , 1 , 0 ),
226+ # size 8 maps to bucket 8//4 = 2, tactic 0 (2 <= M // 2)
227+ (8 , 2 , 0 ),
228+ # size 16 maps to bucket 16//4 = 4, tactic 0 (4 <= M // 2)
229+ (16 , 4 , 0 ),
230+ # size 32 maps to bucket 32//4 = 8, tactic 0 (8 <= M // 2)
231+ (32 , 8 , 0 ),
232+ # size 64 maps to bucket 64//4 = 16, tactic 0 (16 <= M // 2)
233+ (64 , 16 , 0 ),
234+ # size 128 maps to bucket 128//4 = 32, tactic 1 (32 > M // 2)
235+ (128 , 32 , 1 ),
236+ # size 256 maps to bucket 256//4 = 64, tactic -1 (64 > M)
237+ (256 , 64 , - 1 ),
233238 ]
234239
235- for input_size , expected_bucket in test_cases :
236- # Verify the round_rule mapping
237- assert bucket_mapper (input_size , inflate_factor = 4 ) == expected_bucket , \
238- f"bucket_mapper({ input_size } , inflate_factor=4) should be { expected_bucket } , got { bucket_mapper (input_size , inflate_factor = 4 )} "
239-
240+ for input_size , expected_bucket , expected_tactic in test_cases :
240241 # Verify cache lookup succeeds with the mapped bucket
241242 x = torch .randn (input_size , 64 )
242243 runner , tactic = tuner .choose_one ("test_runtime_bucket_mapping" ,
243244 runners , tuning_config , [x , w ])
244- assert tactic != - 1 , \
245- f"Cache miss for input_size={ input_size } , expected to map to bucket { expected_bucket } "
245+ assert (
246+ tactic == expected_tactic
247+ ), f"Cache mismatch for input_size={ input_size } , expected to map to bucket { expected_tactic } but got { tactic } "
246248
247249
248250def test_autotuner_try_block ():
0 commit comments