@@ -29,16 +29,23 @@ def call_kernel(
2929 kernel ,
3030 grid : tuple [int , int ],
3131 transpose_grid : bool ,
32- * args
32+ key : jax .Array ,
33+ total_size : tuple [int , int ],
34+ block_size : tuple [int , int ],
35+ tile_size : tuple [int , int ],
3336 ):
3437 """Calls a kernel over a grid and concatenates results to a single array."""
3538 if transpose_grid :
3639 grid = (grid [1 ], grid [0 ])
3740 m , n = grid
38- return jnp .concatenate ([
39- jnp .concatenate ([
40- kernel ((i , j ), * args ) for j in range (n )], axis = 1 )
41- for i in range (m )], axis = 0 )
41+ samples = jnp .concatenate ([
42+ jnp .concatenate ([
43+ kernel ((i , j ), key , total_size , block_size , tile_size )
44+ for j in range (n )], axis = 1 )
45+ for i in range (m )], axis = 0 )
46+ # Slice out the padding.
47+ samples = samples [:total_size [0 ], :total_size [1 ]]
48+ return samples
4249
4350
4451def call_kernel_3d (
@@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
7380 block_size = block_size ,
7481 tile_size = tile_size )
7582 return blocked_sampler .sample_block (jax .random .uniform ,
76- keys ,
77- block_size = block_size ,
78- tile_size = tile_size ,
79- minval = 0.0 , maxval = 1.0 )
83+ keys ,
84+ block_size = block_size ,
85+ tile_size = tile_size ,
86+ minval = 0.0 , maxval = 1.0 )
8087
8188
8289class BlockedSamplerTest (jtu .JaxTestCase ):
@@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
94101 dict (testcase_name = '16x256_vs_32x128' , total_size = (32 , 256 ),
95102 block_size_a = (16 , 256 ), block_size_b = (32 , 128 ),
96103 tile_size = (8 , 128 ), transpose_grid = False ),
104+ dict (testcase_name = '128x128_vs_128x256_padding' ,
105+ total_size = (256 , 128 ), block_size_a = (128 , 128 ),
106+ block_size_b = (128 , 256 ), tile_size = (128 , 128 ), transpose_grid = False ),
107+ dict (testcase_name = '128x128_vs_128x256_padding2' ,
108+ total_size = (257 , 129 ), block_size_a = (128 , 128 ),
109+ block_size_b = (128 , 256 ), tile_size = (128 , 128 ), transpose_grid = False ),
97110 )
98111 def test_block_shape_invariance (self , total_size , block_size_a ,
99112 block_size_b , tile_size , transpose_grid ):
100113 global_key = jax .random .key (0 )
101- grid_a = tuple (_tot // _blk for _tot , _blk in zip (total_size , block_size_a ))
114+ ceil_div = lambda x , y : (x + y - 1 ) // y
115+ grid_a = tuple (ceil_div (_tot , _blk )
116+ for _tot , _blk in zip (total_size , block_size_a ))
102117 result_a = call_kernel (
103118 uniform_kernel , grid_a , transpose_grid , global_key ,
104119 total_size , block_size_a , tile_size )
105120
106- grid_b = tuple (_tot // _blk for _tot , _blk in zip (total_size , block_size_b ))
121+ grid_b = tuple (ceil_div (_tot , _blk )
122+ for _tot , _blk in zip (total_size , block_size_b ))
107123 result_b = call_kernel (
108124 uniform_kernel , grid_b , transpose_grid , global_key ,
109125 total_size , block_size_b , tile_size )
0 commit comments