Skip to content

Commit d9a48dd

Browse files
junwhanahnGoogle-ML-Automation
authored andcommitted
Fix an overflow issue in TransposePlan
PiperOrigin-RevId: 715162849
1 parent 3aa5d48 commit d9a48dd

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

xla/pjrt/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,7 @@ xla_cc_test(
768768
"//xla:shape_util",
769769
"//xla:util",
770770
"//xla/hlo/testlib:test",
771+
"//xla/tsl/lib/core:status_test_util",
771772
"//xla/tsl/protobuf:error_codes_proto_impl_cc",
772773
"@com_google_absl//absl/container:inlined_vector",
773774
"@com_google_absl//absl/numeric:int128",

xla/pjrt/transpose.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,7 @@ void TransposePlan::BuildPlanNodes(
900900

901901
absl::StatusOr<std::unique_ptr<TransposePlan>> TransposePlan::Create(
902902
const Options& o) {
903-
auto is_negative = [](int d) { return d < 0; };
903+
auto is_negative = [](int64_t d) { return d < 0; };
904904
if (absl::c_find_if(o.dims, is_negative) != o.dims.end()) {
905905
return InvalidArgument("dims must be non-negative, got %s",
906906
absl::StrJoin(o.dims, ","));

xla/pjrt/transpose_test.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ limitations under the License.
3333
#include "xla/hlo/testlib/test.h"
3434
#include "xla/permutation_util.h"
3535
#include "xla/shape_util.h"
36+
#include "xla/tsl/lib/core/status_test_util.h"
3637
#include "xla/tsl/protobuf/error_codes.pb.h"
3738
#include "xla/util.h"
3839
#include "tsl/platform/statusor.h"
@@ -140,6 +141,20 @@ TEST(TransposeTest, InvalidTilings) {
140141
"Only one of the input and output may have a non-trivial tiling"));
141142
}
142143

144+
TEST(TransposeTest, LargeDimensions) {
145+
std::vector<int64_t> dims = {3ll << 30};
146+
std::vector<int64_t> permutation = {0};
147+
148+
TransposePlan::Options options;
149+
options.elem_size_in_bytes = 8;
150+
options.dims = dims;
151+
options.permutation = permutation;
152+
options.input_layout = TransposePlan::Tiling{};
153+
options.output_tiling = TransposePlan::Tiling{};
154+
options.transformation = TransposePlan::Transformation::kNone;
155+
TF_EXPECT_OK(TransposePlan::Create(options).status());
156+
}
157+
143158
// Computes the size in elements of a tiled array.
144159
int64_t SizeOfTiledArray(absl::Span<int64_t const> shape,
145160
absl::Span<int64_t const> tiling) {

0 commit comments

Comments
 (0)