Skip to content

Commit 4bb0114

Browse files
committed
revise test
1 parent d60d2c8 commit 4bb0114

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -215,34 +215,36 @@ def bucket_mapper(x: int, inflate_factor: int) -> int:
215215
"test_runtime_bucket_mapping")
216216

217217
# Extract the first dimension of the first input shape from each cache key
218-
cached_bucket_sizes = sorted([key[3][0][0] for key in cache_entries.keys()])
219-
expected_buckets = [1, 2, 4, 8, 16, 32]
220-
assert cached_bucket_sizes == expected_buckets, \
221-
f"Expect {expected_buckets}, got {cached_bucket_sizes}"
218+
assert len(cache_entries) == 6, \
219+
f"Expected 6 cache entries (buckets 1, 2, 4, 8, 16, 32), got {len(cache_entries)}"
222220

223221
# Test runtime mapping: input size should be mapped via map_to_runtime_buckets
224222
# to find the correct tuning bucket
225223
test_cases = [
226-
# input_size, expected_mapped_bucket via round_rule
227-
(4, 1), # round_rule(4) = 4 // 4 = 1
228-
(8, 2), # round_rule(8) = 8 // 4 = 2
229-
(16, 4), # round_rule(16) = 16 // 4 = 4
230-
(32, 8), # round_rule(32) = 32 // 4 = 8
231-
(64, 16), # round_rule(64) = 64 // 4 = 16
232-
(128, 32), # round_rule(128) = 128 // 4 = 32
224+
# size 4 maps to bucket 4//4 = 1, tactic 0 (1 <= M // 2)
225+
(4, 1, 0),
226+
# size 8 maps to bucket 8//4 = 2, tactic 0 (2 <= M // 2)
227+
(8, 2, 0),
228+
# size 16 maps to bucket 16//4 = 4, tactic 0 (4 <= M // 2)
229+
(16, 4, 0),
230+
# size 32 maps to bucket 32//4 = 8, tactic 0 (8 <= M // 2)
231+
(32, 8, 0),
232+
# size 64 maps to bucket 64//4 = 16, tactic 0 (16 <= M // 2)
233+
(64, 16, 0),
234+
# size 128 maps to bucket 128//4 = 32, tactic 1 (32 > M // 2)
235+
(128, 32, 1),
236+
# size 256 maps to bucket 256//4 = 64, tactic -1 (64 > M)
237+
(256, 64, -1),
233238
]
234239

235-
for input_size, expected_bucket in test_cases:
236-
# Verify the round_rule mapping
237-
assert bucket_mapper(input_size, inflate_factor=4) == expected_bucket, \
238-
f"bucket_mapper({input_size}, inflate_factor=4) should be {expected_bucket}, got {bucket_mapper(input_size, inflate_factor=4)}"
239-
240+
for input_size, expected_bucket, expected_tactic in test_cases:
240241
# Verify cache lookup succeeds with the mapped bucket
241242
x = torch.randn(input_size, 64)
242243
runner, tactic = tuner.choose_one("test_runtime_bucket_mapping",
243244
runners, tuning_config, [x, w])
244-
assert tactic != -1, \
245-
f"Cache miss for input_size={input_size}, expected to map to bucket {expected_bucket}"
245+
assert (
246+
tactic == expected_tactic
247+
), f"Cache mismatch for input_size={input_size}, expected to map to bucket {expected_tactic} but got {tactic}"
246248

247249

248250
def test_autotuner_try_block():

0 commit comments

Comments
 (0)