File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed
jaxlib/mosaic/dialect/tpu Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License.
1717
1818#include < algorithm>
1919#include < array>
20+ #include < cstddef>
2021#include < cstdint>
2122#include < optional>
2223#include < ostream>
@@ -53,13 +54,14 @@ std::ostream &operator<<(std::ostream &os, Print p) {
5354
5455SmallVector<int64_t > ComputeTileStrides (absl::Span<const int64_t > shape,
5556 absl::Span<const int64_t > tiling) {
57+ CHECK_LE (tiling.size (), shape.size ());
5658 SmallVector<int64_t > tile_strides (shape.size ());
5759 int64_t stride = 1 ;
58- for (int64_t i = 0 ; i < shape.size (); ++i) {
59- int64_t idx = shape.size () - 1 - i;
60- int64_t tiling_idx = tiling.size () - 1 - i;
60+ for (size_t i = 0 ; i < shape.size (); ++i) {
61+ const size_t idx = shape.size () - 1 - i;
6162 tile_strides[idx] = stride;
62- if (tiling_idx >= 0 ) {
63+ if (i < tiling.size ()) {
64+ const size_t tiling_idx = tiling.size () - 1 - i;
6365 stride *= llvm::divideCeil (shape[idx], tiling[tiling_idx]);
6466 } else {
6567 stride *= shape[idx];
You can’t perform that action at this time.
0 commit comments