Skip to content

Commit 8b7cfcb

Browse files
dfm authored and Google-ML-Automation committed
Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
1 parent e33f3fc commit 8b7cfcb

File tree

4 files changed

+41
-8
lines changed

4 files changed

+41
-8
lines changed

jaxlib/gpu/rnn.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include <cstddef>
17+
1618
#include "nanobind/nanobind.h"
1719
#include "nanobind/stl/pair.h"
1820
#include "jaxlib/absl_status_casters.h"
@@ -29,7 +31,7 @@ namespace nb = nanobind;
2931
nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
3032
int batch_size, int max_seq_length, float dropout,
3133
bool bidirectional, bool cudnn_allow_tf32,
32-
int workspace_size, int reserve_space_size) {
34+
size_t workspace_size, size_t reserve_space_size) {
3335
return PackDescriptor(RnnDescriptor{
3436
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
3537
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});

jaxlib/gpu/rnn_kernels.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "jaxlib/gpu/rnn_kernels.h"
1717

18+
#include <cstddef>
1819
#include <utility>
1920
#include <vector>
2021

@@ -71,7 +72,7 @@ template <>
7172

7273
namespace JAX_GPU_NAMESPACE {
7374

74-
static absl::StatusOr<std::pair<int, int>>
75+
static absl::StatusOr<std::pair<size_t, size_t>>
7576
DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
7677
int num_layers, int batch_size,
7778
int max_seq_length, float dropout,
@@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
174175
return std::make_pair(workSpaceSize, reserveSpaceSize);
175176
}
176177

177-
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
178+
absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
178179
int input_size, int hidden_size, int num_layers, int batch_size,
179180
int max_seq_length, float dropout, bool bidirectional,
180181
bool cudnn_allow_tf32) {

jaxlib/gpu/rnn_kernels.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
#ifndef JAXLIB_GPU_RNN_KERNELS_H_
1717
#define JAXLIB_GPU_RNN_KERNELS_H_
1818

19+
#include <cstddef>
20+
1921
#include "absl/status/statusor.h"
2022
#include "jaxlib/gpu/vendor.h"
2123
#include "xla/ffi/api/ffi.h"
@@ -34,12 +36,12 @@ struct RnnDescriptor {
3436
float dropout;
3537
int bidirectional;
3638
int cudnn_allow_tf32;
37-
int workspace_size;
38-
int reserve_space_size;
39+
size_t workspace_size;
40+
size_t reserve_space_size;
3941
};
4042

4143
// Return (workspace size, reserve space size).
42-
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
44+
absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
4345
int input_size, int hidden_size, int num_layers, int batch_size,
4446
int max_seq_length, float dropout, bool bidirectional,
4547
bool cudnn_allow_tf32);

tests/experimental_rnn_test.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,36 @@ def f(k1, k2, k3, k4):
213213

214214
k = jax.random.split(jax.random.PRNGKey(1), 4)
215215
stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
216-
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
217-
stablehlo)
216+
if jtu.jaxlib_version() <= (0, 5, 2):
217+
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
218+
stablehlo)
219+
else:
220+
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"',
221+
stablehlo)
222+
223+
@jtu.run_on_devices("cuda")
224+
def test_no_workspace_overflow(self):
225+
if jtu.jaxlib_version() <= (0, 5, 2):
226+
self.skipTest("Older versions fail because of integer overflow.")
227+
228+
# Problem sizes known to cause overflows on older versions.
229+
batch_size, max_seq_length, input_size = 256, 500, 512
230+
num_layers, hidden_size = 1, 256
231+
num_params = rnn.get_num_params_in_lstm(
232+
input_size, hidden_size, num_layers, True)
233+
x = jax.ShapeDtypeStruct(
234+
(batch_size, max_seq_length, input_size), jnp.float32)
235+
h_0 = jax.ShapeDtypeStruct(
236+
(2 * num_layers, batch_size, hidden_size), jnp.float32)
237+
c_0 = jax.ShapeDtypeStruct(
238+
(2 * num_layers, batch_size, hidden_size), jnp.float32)
239+
weights = jax.ShapeDtypeStruct((num_params,), jnp.float32)
240+
seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32)
241+
fun = jax.jit(partial(
242+
rnn.lstm, input_size=input_size, hidden_size=hidden_size,
243+
num_layers=num_layers, dropout=0.0, bidirectional=True))
244+
fun.lower(x, h_0, c_0, weights, seq_lengths) # Doesn't crash.
245+
218246

219247
if __name__ == '__main__':
220248
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments (0)