@@ -37,18 +37,41 @@ def call_kernel(
3737 m , n = grid
3838 return jnp .concatenate ([
3939 jnp .concatenate ([
40- kernel (i , j , * args ) for j in range (n )], axis = 1 )
40+ kernel (( i , j ) , * args ) for j in range (n )], axis = 1 )
4141 for i in range (m )], axis = 0 )
4242
4343
44- def uniform_kernel (i : int , j : int , total_size , block_size , tile_size ):
45- """Uniform random sampling kernel function."""
46- global_key = jax .random .key (0 )
47- keys = blocked_sampler .blocked_fold_in (global_key ,
44+ def call_kernel_3d (
45+ kernel ,
46+ grid : tuple [int , int ],
47+ * args
48+ ):
49+ """Calls a kernel over a 3D grid and concatenates results to a single array."""
50+ depth , rows , cols = grid
51+ return jnp .concatenate ([
52+ jnp .concatenate ([
53+ jnp .concatenate ([
54+ jnp .array (kernel ((i , j , k ), * args ))
55+ for k in range (cols )], axis = 2 )
56+ for j in range (rows )], axis = 1 )
57+ for i in range (depth )], axis = 0 )
58+
59+
60+ def blocked_fold_in (block_index , key , total_size , block_size , tile_size ):
61+ """Folds in block_index into global_key."""
62+ return blocked_sampler .blocked_fold_in (key ,
4863 total_size = total_size ,
4964 block_size = block_size ,
5065 tile_size = tile_size ,
51- block_index = (i , j ))
66+ block_index = block_index )
67+
68+
69+ def uniform_kernel (block_index , key , total_size , block_size , tile_size ):
70+ """Uniform random sampling kernel function."""
71+ keys = blocked_fold_in (block_index , key ,
72+ total_size = total_size ,
73+ block_size = block_size ,
74+ tile_size = tile_size )
5275 return blocked_sampler .sample_block (jax .random .uniform ,
5376 keys ,
5477 block_size = block_size ,
@@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
7497 )
7598 def test_block_shape_invariance (self , total_size , block_size_a ,
7699 block_size_b , tile_size , transpose_grid ):
100+ global_key = jax .random .key (0 )
77101 grid_a = tuple (_tot // _blk for _tot , _blk in zip (total_size , block_size_a ))
78102 result_a = call_kernel (
79- uniform_kernel , grid_a , transpose_grid ,
103+ uniform_kernel , grid_a , transpose_grid , global_key ,
80104 total_size , block_size_a , tile_size )
81105
82106 grid_b = tuple (_tot // _blk for _tot , _blk in zip (total_size , block_size_b ))
83107 result_b = call_kernel (
84- uniform_kernel , grid_b , transpose_grid ,
108+ uniform_kernel , grid_b , transpose_grid , global_key ,
85109 total_size , block_size_b , tile_size )
86110 np .testing .assert_array_equal (result_a , result_b )
87111
88112
113+ class BlockedFoldInTest (jtu .JaxTestCase ):
114+ @parameterized .named_parameters (
115+ # Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works
116+ # as expected. Specifically, blocked key folding does not depend on the total
117+ # size of the tensor, but only the total number of tiles.
118+ # Using a 3D grid (with very large inner dimensions) triggers an overflow in a
119+ # previous implementation of blocked_fold_in.
120+ dict (testcase_name = '4096x512_vs_1024x2048' ,
121+ total_size = (2 , 64 * 1024 , 64 * 1024 ), block_size_a = (1 , 4096 , 512 ),
122+ block_size_b = (1 , 1024 , 2048 ), tile_size = (1 , 1024 , 512 )),
123+ )
124+ def test_blocked_fold_in_shape_invariance (self , total_size , block_size_a ,
125+ block_size_b , tile_size ):
126+ global_key = jax .random .key (0 )
127+ grid_a = tuple (_tot // _blk for _tot , _blk in zip (total_size , block_size_a ))
128+ result_a = call_kernel_3d (
129+ blocked_fold_in , grid_a , global_key , total_size ,
130+ block_size_a , tile_size )
131+
132+ grid_b = tuple (_tot // _blk for _tot , _blk in zip (total_size , block_size_b ))
133+ result_b = call_kernel_3d (
134+ blocked_fold_in , grid_b , global_key , total_size ,
135+ block_size_b , tile_size )
136+ np .testing .assert_array_equal (jax .random .key_data (result_a ),
137+ jax .random .key_data (result_b ))
138+
139+
140+
89141if __name__ == "__main__" :
90142 absltest .main ()
0 commit comments