Skip to content

Commit f3bb52b

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic][util] Fix ComputeTileStrides for 0D tiles
Previously, `tiling.size() - 1` could underflow for 0D tiles, leading to and OOB access and crash PiperOrigin-RevId: 833506958
1 parent 2cc6522 commit f3bb52b

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

jaxlib/mosaic/dialect/tpu/util.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include <algorithm>
1919
#include <array>
20+
#include <cstddef>
2021
#include <cstdint>
2122
#include <optional>
2223
#include <ostream>
@@ -53,13 +54,14 @@ std::ostream &operator<<(std::ostream &os, Print p) {
5354

5455
SmallVector<int64_t> ComputeTileStrides(absl::Span<const int64_t> shape,
5556
absl::Span<const int64_t> tiling) {
57+
CHECK_LE(tiling.size(), shape.size());
5658
SmallVector<int64_t> tile_strides(shape.size());
5759
int64_t stride = 1;
58-
for (int64_t i = 0; i < shape.size(); ++i) {
59-
int64_t idx = shape.size() - 1 - i;
60-
int64_t tiling_idx = tiling.size() - 1 - i;
60+
for (size_t i = 0; i < shape.size(); ++i) {
61+
const size_t idx = shape.size() - 1 - i;
6162
tile_strides[idx] = stride;
62-
if (tiling_idx >= 0) {
63+
if (i < tiling.size()) {
64+
const size_t tiling_idx = tiling.size() - 1 - i;
6365
stride *= llvm::divideCeil(shape[idx], tiling[tiling_idx]);
6466
} else {
6567
stride *= shape[idx];

0 commit comments

Comments
 (0)