Skip to content

Commit b496613

Browse files
Compute tile index using tile-based coordinates
This reduces the chances of overflowing a 32-bit integer when computing tile indices. Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`. PiperOrigin-RevId: 737778853
1 parent b74b16f commit b496613

File tree

2 files changed

+78
-21
lines changed

2 files changed

+78
-21
lines changed

jax/_src/blocked_sampler.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,17 @@ def __call__(self, key: ArrayLike, *args, shape: Shape,
2828
...
2929

3030

31-
def _compute_scalar_index(iteration_index: Sequence[int],
32-
total_size: Shape,
33-
block_size: Shape,
34-
block_index: Sequence[int]) -> int:
35-
ndims = len(iteration_index)
31+
def _compute_tile_index(block_index: Sequence[int],
32+
total_size_in_blocks: Shape,
33+
block_size_in_tiles: Shape,
34+
tile_index_in_block: Sequence[int]) -> int:
35+
ndims = len(block_index)
3636
dim_size = 1
3737
total_idx = 0
3838
for i in range(ndims-1, -1, -1):
39-
dim_idx = block_index[i] + iteration_index[i] * block_size[i]
39+
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
4040
total_idx += dim_idx * dim_size
41-
dim_size *= total_size[i]
41+
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
4242
return total_idx
4343

4444

@@ -99,18 +99,23 @@ def blocked_fold_in(
9999
An N-dimensional nested list of keys required to sample the tiles
100100
corresponding to the block specified by `block_index`.
101101
"""
102-
size_in_blocks = tuple(
103-
_shape // _element for _shape, _element in zip(block_size, tile_size))
102+
block_size_in_tiles = tuple(
103+
_shape // _element for _shape, _element in zip(block_size, tile_size)
104+
)
105+
106+
total_size_in_blocks = tuple(
107+
_shape // _element for _shape, _element in zip(total_size, block_size)
108+
)
104109

105110
def _keygen_loop(axis, prefix):
106-
if axis == len(size_in_blocks):
111+
if axis == len(block_size_in_tiles):
107112
subtile_key = jax.random.fold_in(
108-
global_key, _compute_scalar_index(
109-
block_index, total_size, size_in_blocks, prefix))
113+
global_key, _compute_tile_index(
114+
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
110115
return subtile_key
111116
else:
112117
keys = []
113-
for i in range(size_in_blocks[axis]):
118+
for i in range(block_size_in_tiles[axis]):
114119
keys.append(_keygen_loop(axis+1, prefix+(i,)))
115120
return keys
116121
return _keygen_loop(0, tuple())

tests/blocked_sampler_test.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,41 @@ def call_kernel(
3737
m, n = grid
3838
return jnp.concatenate([
3939
jnp.concatenate([
40-
kernel(i, j, *args) for j in range(n)], axis=1)
40+
kernel((i, j), *args) for j in range(n)], axis=1)
4141
for i in range(m)], axis=0)
4242

4343

44-
def uniform_kernel(i: int, j: int, total_size, block_size, tile_size):
45-
"""Uniform random sampling kernel function."""
46-
global_key = jax.random.key(0)
47-
keys = blocked_sampler.blocked_fold_in(global_key,
44+
def call_kernel_3d(
45+
kernel,
46+
grid: tuple[int, int],
47+
*args
48+
):
49+
"""Calls a kernel over a 3D grid and concatenates results to a single array."""
50+
depth, rows, cols = grid
51+
return jnp.concatenate([
52+
jnp.concatenate([
53+
jnp.concatenate([
54+
jnp.array(kernel((i, j, k), *args))
55+
for k in range(cols)], axis=2)
56+
for j in range(rows)], axis=1)
57+
for i in range(depth)], axis=0)
58+
59+
60+
def blocked_fold_in(block_index, key, total_size, block_size, tile_size):
61+
"""Folds in block_index into global_key."""
62+
return blocked_sampler.blocked_fold_in(key,
4863
total_size=total_size,
4964
block_size=block_size,
5065
tile_size=tile_size,
51-
block_index=(i, j))
66+
block_index=block_index)
67+
68+
69+
def uniform_kernel(block_index, key, total_size, block_size, tile_size):
70+
"""Uniform random sampling kernel function."""
71+
keys = blocked_fold_in(block_index, key,
72+
total_size=total_size,
73+
block_size=block_size,
74+
tile_size=tile_size)
5275
return blocked_sampler.sample_block(jax.random.uniform,
5376
keys,
5477
block_size=block_size,
@@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
7497
)
7598
def test_block_shape_invariance(self, total_size, block_size_a,
7699
block_size_b, tile_size, transpose_grid):
100+
global_key = jax.random.key(0)
77101
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
78102
result_a = call_kernel(
79-
uniform_kernel, grid_a, transpose_grid,
103+
uniform_kernel, grid_a, transpose_grid, global_key,
80104
total_size, block_size_a, tile_size)
81105

82106
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
83107
result_b = call_kernel(
84-
uniform_kernel, grid_b, transpose_grid,
108+
uniform_kernel, grid_b, transpose_grid, global_key,
85109
total_size, block_size_b, tile_size)
86110
np.testing.assert_array_equal(result_a, result_b)
87111

88112

113+
class BlockedFoldInTest(jtu.JaxTestCase):
114+
@parameterized.named_parameters(
115+
# Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works
116+
# as expected. Specifically, blocked key folding does not depend on the total
117+
# size of the tensor, but only the total number of tiles.
118+
# Using a 3D grid (with very large inner dimensions) triggers an overflow in a
119+
# previous implementation of blocked_fold_in.
120+
dict(testcase_name='4096x512_vs_1024x2048',
121+
total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512),
122+
block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)),
123+
)
124+
def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a,
125+
block_size_b, tile_size):
126+
global_key = jax.random.key(0)
127+
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
128+
result_a = call_kernel_3d(
129+
blocked_fold_in, grid_a, global_key, total_size,
130+
block_size_a, tile_size)
131+
132+
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
133+
result_b = call_kernel_3d(
134+
blocked_fold_in, grid_b, global_key, total_size,
135+
block_size_b, tile_size)
136+
np.testing.assert_array_equal(jax.random.key_data(result_a),
137+
jax.random.key_data(result_b))
138+
139+
140+
89141
if __name__ == "__main__":
90142
absltest.main()

0 commit comments

Comments
 (0)