1818from absl .testing import absltest
1919from absl .testing import parameterized
2020import jax
21+ from jax ._src import test_util as jtu
2122from jax .experimental .pallas .ops .tpu .splash_attention import splash_attention_mask as mask_lib
2223from jax .experimental .pallas .ops .tpu .splash_attention import splash_attention_mask_info as mask_info_lib
23- from jax ._src import test_util as jtu
2424import numpy as np
2525
2626jax .config .parse_flags_with_absl ()
@@ -798,7 +798,7 @@ def test_two_causal_masks(self, is_lazy_mask: bool):
798798 self ._expected_causal_data_next [None ],
799799 self ._expected_causal_mask_next (0 )[None ] if not is_lazy_mask else None ,
800800 self ._expected_causal_block_mask [None ],
801- np .expand_dims (np .tril (np .ones (block_shape , dtype = np .int32 )), 0 )
801+ np .expand_dims (np .tril (np .ones (block_shape , dtype = np .bool_ )), 0 )
802802 if not is_lazy_mask
803803 else None ,
804804 np .arange (sequence_lengths [0 ], dtype = np .int32 )
@@ -813,7 +813,7 @@ def test_two_causal_masks(self, is_lazy_mask: bool):
813813 else None ,
814814 self ._expected_causal_block_mask_dkv [None ],
815815 np .expand_dims (
816- np .tril (np .ones (block_shape , dtype = np .int32 )), 0
816+ np .tril (np .ones (block_shape , dtype = np .bool_ )), 0
817817 ).swapaxes (- 1 , - 2 )
818818 if not is_lazy_mask
819819 else None ,
@@ -851,7 +851,7 @@ def test_rectangular_wide_causal_mask(self, is_lazy_mask: bool):
851851 self ._expected_causal_data_next [None ],
852852 self ._expected_causal_mask_next (0 )[None ] if not is_lazy_mask else None ,
853853 self ._expected_causal_block_mask [None ],
854- np .expand_dims (np .tril (np .ones (block_shape , dtype = np .int32 )), 0 )
854+ np .expand_dims (np .tril (np .ones (block_shape , dtype = np .bool_ )), 0 )
855855 if not is_lazy_mask
856856 else None ,
857857 np .arange (sequence_lengths [0 ], dtype = np .int32 )
@@ -894,7 +894,7 @@ def test_rectangular_wide_causal_mask(self, is_lazy_mask: bool):
894894 expected_causal_mask_next_dkv if not is_lazy_mask else None ,
895895 expected_causal_block_mask_dkv ,
896896 np .expand_dims (
897- np .tril (np .ones (block_shape , dtype = np .int32 )), 0
897+ np .tril (np .ones (block_shape , dtype = np .bool_ )), 0
898898 ).swapaxes (- 1 , - 2 )
899899 if not is_lazy_mask
900900 else None ,
@@ -974,7 +974,7 @@ def test_rectangular_tall_causal_mask(self, is_lazy_mask: bool):
974974 expected_causal_data_next ,
975975 expected_causal_mask_next if not is_lazy_mask else None ,
976976 expected_causal_block_mask ,
977- np .expand_dims (np .tril (np .ones (block_shape , dtype = np .int32 )), 0 )
977+ np .expand_dims (np .tril (np .ones (block_shape , dtype = np .bool_ )), 0 )
978978 if not is_lazy_mask
979979 else None ,
980980 np .arange (sequence_lengths [0 ], dtype = np .int32 )
@@ -1029,7 +1029,7 @@ def test_rectangular_tall_causal_mask(self, is_lazy_mask: bool):
10291029 expected_causal_mask_next_dkv if not is_lazy_mask else None ,
10301030 expected_causal_block_mask_dkv ,
10311031 np .expand_dims (
1032- np .tril (np .ones (block_shape , dtype = np .int32 )), 0
1032+ np .tril (np .ones (block_shape , dtype = np .bool_ )), 0
10331033 ).swapaxes (- 1 , - 2 )
10341034 if not is_lazy_mask
10351035 else None ,
@@ -1069,10 +1069,10 @@ def test_local_mask(self, is_lazy_mask: bool):
10691069 expected_partial_mask_blocks = self ._stack (
10701070 [
10711071 np .triu (
1072- np .tri (* block_shape , window_size , dtype = np .int32 ), - window_size
1072+ np .tri (* block_shape , window_size , dtype = np .bool_ ), - window_size
10731073 ),
1074- np .tri (* block_shape , - window_size , dtype = np .int32 ),
1075- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1074+ np .tri (* block_shape , - window_size , dtype = np .bool_ ),
1075+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
10761076 ],
10771077 )
10781078
@@ -1179,8 +1179,8 @@ def test_local_mask_narrow(self, is_lazy_mask: bool):
11791179
11801180 expected_partial_mask_blocks = self ._stack (
11811181 [
1182- np .triu (np .tri (* block_shape , 0 , dtype = np .int32 ), - window_size ),
1183- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1182+ np .triu (np .tri (* block_shape , 0 , dtype = np .bool_ ), - window_size ),
1183+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
11841184 ],
11851185 )
11861186
@@ -1298,13 +1298,13 @@ def test_two_head_shards_one_causal_one_local(self, is_lazy_mask: bool):
12981298 )
12991299
13001300 expected_partial_mask_blocks = self ._stack ([
1301- np .tril (np .ones (block_shape , dtype = np .int32 )),
1301+ np .tril (np .ones (block_shape , dtype = np .bool_ )),
13021302 np .triu (
1303- np .tri (* block_shape , window_size , dtype = np .int32 ),
1303+ np .tri (* block_shape , window_size , dtype = np .bool_ ),
13041304 - window_size ,
13051305 ),
1306- np .tri (* block_shape , - window_size , dtype = np .int32 ),
1307- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1306+ np .tri (* block_shape , - window_size , dtype = np .bool_ ),
1307+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
13081308 ])
13091309
13101310 expected_block_mask_dkv = self ._stack (
@@ -1384,7 +1384,7 @@ def test_two_head_shards_causal_full(self, is_lazy_mask: bool):
13841384 ])
13851385
13861386 expected_partial_mask_blocks = np .expand_dims (
1387- np .tril (np .ones (block_shape , dtype = np .int32 )), 0
1387+ np .tril (np .ones (block_shape , dtype = np .bool_ )), 0
13881388 )
13891389
13901390 expected_mask_info = mask_info_lib .MaskInfo (
@@ -1460,13 +1460,13 @@ def test_two_qseq_shards_causal_local(self, is_lazy_mask: bool):
14601460 )
14611461
14621462 expected_partial_mask_blocks = self ._stack ([
1463- np .tril (np .ones (block_shape , dtype = np .int32 )),
1463+ np .tril (np .ones (block_shape , dtype = np .bool_ )),
14641464 np .triu (
1465- np .tri (* block_shape , window_size , dtype = np .int32 ),
1465+ np .tri (* block_shape , window_size , dtype = np .bool_ ),
14661466 - window_size ,
14671467 ),
1468- np .tri (* block_shape , - window_size , dtype = np .int32 ),
1469- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1468+ np .tri (* block_shape , - window_size , dtype = np .bool_ ),
1469+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
14701470 ])
14711471
14721472 expected_mask_info = mask_info_lib .MaskInfo (
@@ -1577,13 +1577,13 @@ def test_two_qseq_shards_causal_local_stacked(self):
15771577 )
15781578
15791579 expected_partial_mask_blocks = self ._stack ([
1580- np .tril (np .ones (block_shape , dtype = np .int32 )),
1580+ np .tril (np .ones (block_shape , dtype = np .bool_ )),
15811581 np .triu (
1582- np .tri (* block_shape , window_size , dtype = np .int32 ),
1582+ np .tri (* block_shape , window_size , dtype = np .bool_ ),
15831583 - window_size ,
15841584 ),
1585- np .tri (* block_shape , - window_size , dtype = np .int32 ),
1586- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1585+ np .tri (* block_shape , - window_size , dtype = np .bool_ ),
1586+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
15871587 ])
15881588
15891589 expected_mask_info = mask_info_lib .MaskInfo (
@@ -1749,13 +1749,13 @@ def test_two_qseq_shards_local_wide_local_narrow_stacked(self):
17491749 expected_partial_mask_blocks = self ._stack ([
17501750 # Wide
17511751 np .triu (
1752- np .tri (* block_shape , window_size , dtype = np .int32 ),
1752+ np .tri (* block_shape , window_size , dtype = np .bool_ ),
17531753 - window_size ,
17541754 ),
1755- np .tri (* block_shape , - window_size , dtype = np .int32 ),
1756- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1755+ np .tri (* block_shape , - window_size , dtype = np .bool_ ),
1756+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
17571757 # Narrow
1758- np .triu (np .tri (* block_shape , 0 , dtype = np .int32 ), - window_size ),
1758+ np .triu (np .tri (* block_shape , 0 , dtype = np .bool_ ), - window_size ),
17591759 ])
17601760
17611761 expected_mask_info = mask_info_lib .MaskInfo (
@@ -1890,7 +1890,7 @@ def test_two_head_shards_causal_mask(self, is_lazy_mask: bool):
18901890 )
18911891
18921892 expected_partial_mask_blocks = np .expand_dims (
1893- np .tril (np .ones (block_shape , dtype = np .int32 )), 0
1893+ np .tril (np .ones (block_shape , dtype = np .bool_ )), 0
18941894 )
18951895
18961896 expected_mask_info = mask_info_lib .MaskInfo (
@@ -1979,13 +1979,13 @@ def test_two_head_shards_two_causal_two_local(self, is_lazy_mask: bool):
19791979
19801980 expected_partial_mask_blocks = self ._stack (
19811981 [
1982- np .tril (np .ones (block_shape , dtype = np .int32 )),
1982+ np .tril (np .ones (block_shape , dtype = np .bool_ )),
19831983 np .triu (
1984- np .tri (* block_shape , window_size , dtype = np .int32 ),
1984+ np .tri (* block_shape , window_size , dtype = np .bool_ ),
19851985 - window_size ,
19861986 ),
1987- np .tri (* block_shape , - window_size , dtype = np .int32 ),
1988- np .triu (np .ones (block_shape , dtype = np .int32 ), window_size ),
1987+ np .tri (* block_shape , - window_size , dtype = np .bool_ ),
1988+ np .triu (np .ones (block_shape , dtype = np .bool_ ), window_size ),
19891989 ],
19901990 )
19911991
0 commit comments