Commit 21e98b5
Fix overflow error in GPU batched linear algebra kernels.
As reported in jax-ml#24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.
PiperOrigin-RevId: 6956946481 parent 9bb6366 commit 21e98b5
File tree
3 files changed
+19
-5
lines changed- jaxlib/gpu
- tests
3 files changed
+19
-5
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
24 | 25 | | |
25 | 26 | | |
26 | 27 | | |
27 | | - | |
28 | | - | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
29 | 31 | | |
30 | 32 | | |
31 | 33 | | |
32 | 34 | | |
33 | 35 | | |
34 | 36 | | |
35 | 37 | | |
36 | | - | |
37 | | - | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
38 | 41 | | |
39 | 42 | | |
40 | 43 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
22 | 24 | | |
23 | 25 | | |
24 | 26 | | |
25 | | - | |
| 27 | + | |
| 28 | + | |
26 | 29 | | |
27 | 30 | | |
28 | 31 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1450 | 1450 | | |
1451 | 1451 | | |
1452 | 1452 | | |
| 1453 | + | |
| 1454 | + | |
| 1455 | + | |
| 1456 | + | |
| 1457 | + | |
| 1458 | + | |
| 1459 | + | |
| 1460 | + | |
1453 | 1461 | | |
1454 | 1462 | | |
1455 | 1463 | | |
| |||
0 commit comments