Skip to content

Commit 21e98b5

Browse files
dfm authored and Google-ML-Automation committed
Fix overflow error in GPU batched linear algebra kernels.
As reported in jax-ml#24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS. PiperOrigin-RevId: 695694648
1 parent 9bb6366 commit 21e98b5

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

jaxlib/gpu/make_batch_pointers.cu.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "jaxlib/gpu/make_batch_pointers.h"
1717

1818
#include <algorithm>
19+
#include <cstdint>
1920

2021
#include "jaxlib/gpu/vendor.h"
2122

@@ -24,17 +25,19 @@ namespace JAX_GPU_NAMESPACE {
2425

2526
namespace {
2627
__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
27-
int batch, int batch_elem_size) {
28-
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
28+
int64_t batch,
29+
int64_t batch_elem_size) {
30+
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
2931
idx += blockDim.x * gridDim.x) {
3032
buffer_out[idx] = buffer_in + idx * batch_elem_size;
3133
}
3234
}
3335
} // namespace
3436

3537
void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
36-
void* buffer_out, int batch, int batch_elem_size) {
37-
const int block_dim = 128;
38+
void* buffer_out, int64_t batch,
39+
int64_t batch_elem_size) {
40+
const std::size_t block_dim = 128;
3841
const std::size_t grid_dim =
3942
std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
4043
MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(

jaxlib/gpu/make_batch_pointers.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@ limitations under the License.
1616
#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
1717
#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
1818

19+
#include <cstdint>
20+
1921
#include "jaxlib/gpu/vendor.h"
2022

2123
namespace jax {
2224
namespace JAX_GPU_NAMESPACE {
2325

2426
void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
25-
void* buffer_out, int batch, int batch_elem_size);
27+
void* buffer_out, int64_t batch,
28+
int64_t batch_elem_size);
2629

2730
} // namespace JAX_GPU_NAMESPACE
2831
} // namespace jax

tests/linalg_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,14 @@ def testLuBatching(self, shape, dtype):
14501450
self.assertAllClose(ls, actual_ls, rtol=5e-6)
14511451
self.assertAllClose(us, actual_us)
14521452

1453+
@jtu.skip_on_devices("cpu", "tpu")
1454+
@jtu.skip_on_flag("jax_skip_slow_tests", True)
1455+
def testBatchedLuOverflow(self):
1456+
# see https://github.com/jax-ml/jax/issues/24843
1457+
x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32)
1458+
lu, _, _ = lax.linalg.lu(x)
1459+
self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9))
1460+
14531461
@jtu.skip_on_devices("cpu", "tpu")
14541462
@jtu.ignore_warning(category=DeprecationWarning,
14551463
message="backend and device argument")

0 commit comments

Comments (0)