Skip to content

Commit 13541e9

Browse files
Make blocked_fold_in consistent when the block sizes induce padding
Add coverage for padded shapes to unit tests.
PiperOrigin-RevId: 738029476
1 parent 1e36cbe commit 13541e9

File tree

2 files changed

+34
-16
lines changed

2 files changed

+34
-16
lines changed

jax/_src/blocked_sampler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ def __call__(self, key: ArrayLike, *args, shape: Shape,
2929

3030

3131
def _compute_tile_index(block_index: Sequence[int],
32-
total_size_in_blocks: Shape,
3332
block_size_in_tiles: Shape,
33+
total_size_in_tiles: Shape,
3434
tile_index_in_block: Sequence[int]) -> int:
3535
ndims = len(block_index)
3636
dim_size = 1
3737
total_idx = 0
3838
for i in range(ndims-1, -1, -1):
3939
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
4040
total_idx += dim_idx * dim_size
41-
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
41+
dim_size *= total_size_in_tiles[i]
4242
return total_idx
4343

4444

@@ -103,15 +103,17 @@ def blocked_fold_in(
103103
_shape // _element for _shape, _element in zip(block_size, tile_size)
104104
)
105105

106-
total_size_in_blocks = tuple(
107-
_shape // _element for _shape, _element in zip(total_size, block_size)
106+
# Round up to make sure every tile is numbered.
107+
total_size_in_tiles = tuple(
108+
(_shape + _element - 1) // _element
109+
for _shape, _element in zip(total_size, tile_size)
108110
)
109111

110112
def _keygen_loop(axis, prefix):
111113
if axis == len(block_size_in_tiles):
112114
subtile_key = jax.random.fold_in(
113115
global_key, _compute_tile_index(
114-
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
116+
block_index, block_size_in_tiles, total_size_in_tiles, prefix))
115117
return subtile_key
116118
else:
117119
keys = []

tests/blocked_sampler_test.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,23 @@ def call_kernel(
2929
kernel,
3030
grid: tuple[int, int],
3131
transpose_grid: bool,
32-
*args
32+
key: jax.Array,
33+
total_size: tuple[int, int],
34+
block_size: tuple[int, int],
35+
tile_size: tuple[int, int],
3336
):
3437
"""Calls a kernel over a grid and concatenates results to a single array."""
3538
if transpose_grid:
3639
grid = (grid[1], grid[0])
3740
m, n = grid
38-
return jnp.concatenate([
39-
jnp.concatenate([
40-
kernel((i, j), *args) for j in range(n)], axis=1)
41-
for i in range(m)], axis=0)
41+
samples = jnp.concatenate([
42+
jnp.concatenate([
43+
kernel((i, j), key, total_size, block_size, tile_size)
44+
for j in range(n)], axis=1)
45+
for i in range(m)], axis=0)
46+
# Slice out the padding.
47+
samples = samples[:total_size[0], :total_size[1]]
48+
return samples
4249

4350

4451
def call_kernel_3d(
@@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
7380
block_size=block_size,
7481
tile_size=tile_size)
7582
return blocked_sampler.sample_block(jax.random.uniform,
76-
keys,
77-
block_size=block_size,
78-
tile_size=tile_size,
79-
minval=0.0, maxval=1.0)
83+
keys,
84+
block_size=block_size,
85+
tile_size=tile_size,
86+
minval=0.0, maxval=1.0)
8087

8188

8289
class BlockedSamplerTest(jtu.JaxTestCase):
@@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
94101
dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
95102
block_size_a=(16, 256), block_size_b=(32, 128),
96103
tile_size=(8, 128), transpose_grid=False),
104+
dict(testcase_name='128x128_vs_128x256_padding',
105+
total_size=(256, 128), block_size_a=(128, 128),
106+
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
107+
dict(testcase_name='128x128_vs_128x256_padding2',
108+
total_size=(257, 129), block_size_a=(128, 128),
109+
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
97110
)
98111
def test_block_shape_invariance(self, total_size, block_size_a,
99112
block_size_b, tile_size, transpose_grid):
100113
global_key = jax.random.key(0)
101-
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
114+
ceil_div = lambda x, y: (x + y - 1) // y
115+
grid_a = tuple(ceil_div(_tot, _blk)
116+
for _tot, _blk in zip(total_size, block_size_a))
102117
result_a = call_kernel(
103118
uniform_kernel, grid_a, transpose_grid, global_key,
104119
total_size, block_size_a, tile_size)
105120

106-
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
121+
grid_b = tuple(ceil_div(_tot, _blk)
122+
for _tot, _blk in zip(total_size, block_size_b))
107123
result_b = call_kernel(
108124
uniform_kernel, grid_b, transpose_grid, global_key,
109125
total_size, block_size_b, tile_size)

0 commit comments

Comments
 (0)