Skip to content

Commit 0ba4b05

Browse files
chunnienccopybara-github
authored andcommitted
fix py3.13 testing failures
PiperOrigin-RevId: 824686696
1 parent f74158b commit 0ba4b05

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

ai_edge_torch/generative/test/test_kv_cache.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ def _get_test_config(self, num_layers, head_dim, num_query_groups):
4141
)
4242
return config
4343

44+
def _assert_kv_cache_entry_equal(self, kv1, kv2):
45+
self.assertIsInstance(kv1, kv_utils.KVCacheEntry)
46+
self.assertIsInstance(kv2, kv_utils.KVCacheEntry)
47+
self.assertEqual(kv1.kv_layout, kv2.kv_layout)
48+
self.assertTrue(torch.equal(kv1.k_cache, kv2.k_cache))
49+
self.assertTrue(torch.equal(kv1.v_cache, kv2.v_cache))
50+
51+
def _assert_kv_cache_equal(self, kv1, kv2):
52+
self.assertIsInstance(kv1, kv_utils.KVCache)
53+
self.assertIsInstance(kv2, kv_utils.KVCache)
54+
self.assertEqual(len(kv1.caches), len(kv2.caches))
55+
for kv1_entry, kv2_entry in zip(kv1.caches, kv2.caches):
56+
self._assert_kv_cache_entry_equal(kv1_entry, kv2_entry)
57+
4458
def test_cache_udpate(self):
4559
N = 1
4660
HEAD_DIM = 2
@@ -118,7 +132,7 @@ def test_pytree_roundtrip_kv_cache(self):
118132
flat, treespec = pytree.tree_flatten(kv)
119133
self.assertLen(flat, NUM_LAYERS * 2)
120134
kv_unflat = pytree.tree_unflatten(flat, treespec)
121-
self.assertEqual(kv, kv_unflat)
135+
self._assert_kv_cache_equal(kv, kv_unflat)
122136

123137
def test_pytree_roundtrip_kv_cache_derived(self):
124138
NUM_LAYERS = 4
@@ -134,7 +148,7 @@ def test_pytree_roundtrip_kv_cache_derived(self):
134148
flat, treespec = pytree.tree_flatten(kv)
135149
self.assertLen(flat, NUM_LAYERS * 2)
136150
kv_unflat = pytree.tree_unflatten(flat, treespec)
137-
self.assertEqual(kv, kv_unflat)
151+
self._assert_kv_cache_equal(kv, kv_unflat)
138152

139153
def test_pytree_roundtrip_kv_entry(self):
140154
attn_config = cfg.AttentionConfig(
@@ -144,8 +158,7 @@ def test_pytree_roundtrip_kv_entry(self):
144158
flat, treespec = pytree.tree_flatten(kv)
145159
self.assertLen(flat, 2)
146160
kv_unflat = pytree.tree_unflatten(flat, treespec)
147-
self.assertEqual(kv, kv_unflat)
148-
self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
161+
self._assert_kv_cache_entry_equal(kv, kv_unflat)
149162

150163
def test_pytree_roundtrip_kv_entry_derived(self):
151164
attn_config = cfg.AttentionConfig(
@@ -157,8 +170,7 @@ def test_pytree_roundtrip_kv_entry_derived(self):
157170
flat, treespec = pytree.tree_flatten(kv)
158171
self.assertLen(flat, 2)
159172
kv_unflat = pytree.tree_unflatten(flat, treespec)
160-
self.assertEqual(kv, kv_unflat)
161-
self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
173+
self._assert_kv_cache_entry_equal(kv, kv_unflat)
162174

163175

164176
if __name__ == "__main__":

0 commit comments

Comments
 (0)