Skip to content

Commit b27c9d3

Browse files
frgossenGoogle-ML-Automation
authored andcommitted
Support evaluation in the absence of layouts when possible
Only bitcast requires the layout to be known when evaluating HLO. In all other cases, we can evaluate without knowing the layout. This is needed for collective pipelining where we have to analyse while loops before layouts were assigned. PiperOrigin-RevId: 715163612
1 parent d9a48dd commit b27c9d3

File tree

7 files changed

+126
-26
lines changed

7 files changed

+126
-26
lines changed

xla/hlo/evaluator/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ xla_cc_test(
138138
"//xla/hlo/analysis:tuple_points_to_analysis",
139139
"//xla/hlo/builder:xla_builder",
140140
"//xla/hlo/ir:hlo",
141+
"//xla/hlo/parser:hlo_parser",
141142
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
142143
"//xla/hlo/transforms/simplifiers:hlo_element_type_converter",
143144
"//xla/service:call_graph",

xla/hlo/evaluator/hlo_evaluator.cc

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ limitations under the License.
4040
#include "absl/container/flat_hash_map.h"
4141
#include "absl/container/inlined_vector.h"
4242
#include "absl/functional/function_ref.h"
43+
#include "absl/log/check.h"
4344
#include "absl/memory/memory.h"
4445
#include "absl/numeric/bits.h"
4546
#include "absl/status/status.h"
@@ -899,8 +900,10 @@ absl::StatusOr<Literal> HloEvaluator::Evaluate(
899900
const auto& computation_shape =
900901
computation.parameter_instruction(i)->shape();
901902
const auto& arg_shape = arg_literals[i]->shape();
902-
if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
903-
arg_shape)) {
903+
bool ignore_layout = !computation_shape.has_layout();
904+
if (!Shape::Equal()
905+
.IgnoreLayout(ignore_layout)
906+
.MinorToMajorOnlyInLayout()(computation_shape, arg_shape)) {
904907
return InvalidArgument(
905908
"Shape mismatch at parameter %d. Computation expected %s, but arg "
906909
"was %s.",
@@ -1290,12 +1293,30 @@ absl::Status HloEvaluator::EvaluateInternal(
12901293
}
12911294

12921295
absl::Status HloEvaluator::HandleBitcast(const HloInstruction* bitcast) {
1293-
const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
1294-
Literal result(bitcast->shape());
1296+
Shape result_shape = bitcast->shape();
1297+
1298+
// Allow effective scalars without layouts as the result is unambiguous.
1299+
if (!result_shape.has_layout() &&
1300+
ShapeUtil::IsEffectiveScalar(result_shape)) {
1301+
result_shape = LayoutUtil::GetWithDefaultLayout(result_shape);
1302+
}
1303+
1304+
// In general, we require a layout to evaluate a bitcast: this is the only
1305+
// operation where indexing is physical rather than logical.
1306+
if (!result_shape.has_layout()) {
1307+
return InvalidArgument(
1308+
"Evaluator cannot evaluate bitcast for non-scalar operand without "
1309+
"assigned layout.");
1310+
}
1311+
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result_shape));
1312+
1313+
Literal result(result_shape);
1314+
12951315
// Bitcast output is allowed to be smaller than the input if the backend-
12961316
// specific buffer sizes for the input and output are the same. Since the HLO
12971317
// evaluator doesn't have access to the backend-specific shape size function,
12981318
// assume it's OK to bitcast if output <= input.
1319+
const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
12991320
TF_RET_CHECK(operand_literal.size_bytes() >= result.size_bytes());
13001321
memcpy(result.untyped_data(), operand_literal.untyped_data(),
13011322
result.size_bytes());
@@ -1372,8 +1393,11 @@ absl::Status HloEvaluator::HandleParameter(const HloInstruction* parameter) {
13721393
#ifndef NDEBUG
13731394
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
13741395
VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
1375-
DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
1376-
input_literal->shape()))
1396+
bool check_layout = parameter->shape().has_layout();
1397+
DCHECK(Shape::Equal()
1398+
.IgnoreLayout(!check_layout)
1399+
.MinorToMajorOnlyInLayout()(parameter->shape(),
1400+
input_literal->shape()))
13771401
<< "parameter shape is: "
13781402
<< ShapeUtil::HumanStringWithLayout(parameter->shape())
13791403
<< ", but input literal shape is: "
@@ -4723,7 +4747,7 @@ absl::Status HloEvaluator::Preprocess(const HloInstruction* hlo) {
47234747
}
47244748
}
47254749
}
4726-
return ShapeUtil::ValidateShape(hlo->shape());
4750+
return ShapeUtil::ValidateShapeWithOptionalLayout(hlo->shape());
47274751
}
47284752

47294753
absl::Status HloEvaluator::Postprocess(const HloInstruction* hlo) {

xla/hlo/evaluator/hlo_evaluator_test.cc

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ limitations under the License.
4646
#include "xla/hlo/ir/hlo_computation.h"
4747
#include "xla/hlo/ir/hlo_instruction.h"
4848
#include "xla/hlo/ir/hlo_opcode.h"
49+
#include "xla/hlo/parser/hlo_parser.h"
4950
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
5051
#include "xla/hlo/transforms/simplifiers/hlo_element_type_converter.h"
5152
#include "xla/layout_util.h"
@@ -4566,6 +4567,71 @@ ENTRY main {
45664567
}
45674568
}
45684569

4570+
TEST_P(HloEvaluatorBf16Test, BitcastWithoutLayout) {
4571+
const absl::string_view hlo_text_base = R"(
4572+
HloModule Bitcast
4573+
4574+
ENTRY main {
4575+
param = %s[2,4] parameter(0)
4576+
ROOT bitcast = %s[4,2,1] bitcast(%s[2,4] param)
4577+
}
4578+
)";
4579+
std::string hlo_text;
4580+
Literal arg;
4581+
if (use_bfloat16_) {
4582+
hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16");
4583+
arg = LiteralUtil::CreateR2<bfloat16>(
4584+
{{bfloat16(1), bfloat16(2), bfloat16(3), bfloat16(4)},
4585+
{bfloat16(5), bfloat16(6), bfloat16(7), bfloat16(8)}});
4586+
} else {
4587+
hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32");
4588+
arg = LiteralUtil::CreateR2<float>({{1., 2., 3., 4.}, {5., 6., 7., 8.}});
4589+
}
4590+
4591+
HloParserOptions parser_config;
4592+
parser_config.set_fill_missing_layouts(false);
4593+
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnUnverifiedModule(
4594+
hlo_text, HloModuleConfig(), parser_config));
4595+
4596+
absl::StatusOr<Literal> actual = Evaluate({&arg});
4597+
EXPECT_FALSE(actual.ok());
4598+
EXPECT_EQ(actual.status().message(),
4599+
"Evaluator cannot evaluate bitcast for non-scalar operand without "
4600+
"assigned layout.");
4601+
}
4602+
4603+
TEST_P(HloEvaluatorBf16Test, EffectiveScalarBitcastWithoutLayout) {
4604+
const absl::string_view hlo_text_base = R"(
4605+
HloModule Bitcast
4606+
4607+
ENTRY main {
4608+
param = %s[1,1] parameter(0)
4609+
ROOT bitcast = %s[1,1,1] bitcast(%s[1,1] param)
4610+
}
4611+
)";
4612+
std::string hlo_text;
4613+
Literal arg;
4614+
if (use_bfloat16_) {
4615+
hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16");
4616+
arg = LiteralUtil::CreateR2<bfloat16>({{bfloat16(2)}});
4617+
} else {
4618+
hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32");
4619+
arg = LiteralUtil::CreateR2<float>({{2.}});
4620+
}
4621+
4622+
HloParserOptions parser_config;
4623+
parser_config.set_fill_missing_layouts(false);
4624+
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnUnverifiedModule(
4625+
hlo_text, HloModuleConfig(), parser_config));
4626+
4627+
TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&arg}));
4628+
if (use_bfloat16_) {
4629+
EXPECT_TRUE(absl::c_equal(arg.data<bfloat16>(), actual.data<bfloat16>()));
4630+
} else {
4631+
EXPECT_TRUE(absl::c_equal(arg.data<float>(), actual.data<float>()));
4632+
}
4633+
}
4634+
45694635
// Check that s32 under/overflow doesn't trigger a ubsan failure.
45704636
TEST_F(HloEvaluatorTest, Int32Overflow) {
45714637
const absl::string_view hlo_text = R"(

xla/hlo/transforms/simplifiers/hlo_constant_folding_test.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,15 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
275275
ParseAndReturnVerifiedModule(kConstantFoldReduce));
276276
HloInstruction* add = (*m->computations().begin())->root_instruction();
277277
LayoutUtil::ClearLayout(add->mutable_shape());
278+
278279
HloConstantFolding const_folder;
279280
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get()));
280-
EXPECT_FALSE(result);
281+
EXPECT_TRUE(result);
281282

282-
EXPECT_THAT(m->entry_computation()->root_instruction(),
283-
GmockMatch(m::Reduce()));
283+
EXPECT_EQ(6, m->entry_computation()
284+
->root_instruction()
285+
->literal()
286+
.GetFirstElement<int32_t>());
284287
}
285288

286289
const char* const kConstantFoldLargePad = R"(

xla/literal.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ const Shape* TryInternShape(const Shape& shape) {
137137
return &NilShape();
138138
}
139139
if (shape.IsArray() && shape.dimensions_size() == 0 && shape.is_static() &&
140-
shape.layout().tiles_size() == 0 && shape.layout().memory_space() == 0) {
140+
shape.has_layout() && shape.layout().tiles_size() == 0 &&
141+
shape.layout().memory_space() == 0 &&
142+
shape.layout().element_size_in_bits() == 0) {
141143
return &ScalarShape(shape.element_type());
142144
}
143145
return nullptr;
@@ -252,18 +254,20 @@ Literal::Literal(const Shape& shape)
252254
: Literal(shape, /*allocate_arrays=*/true) {}
253255

254256
void Literal::SetShape(const Shape& shape) {
255-
Shape shape_storage;
256-
const Shape* shape_ptr = &shape;
257-
if (shape.IsArray() && LayoutUtil::HasCustomElementSizeInBits(shape)) {
258-
shape_storage = shape;
259-
shape_storage.mutable_layout()->set_element_size_in_bits(0);
260-
shape_ptr = &shape_storage;
261-
}
262-
if (const Shape* intered_shape_ptr = TryInternShape(*shape_ptr)) {
257+
if (const Shape* intered_shape_ptr = TryInternShape(shape)) {
263258
shape_ = intered_shape_ptr;
264-
} else {
265-
shape_ = std::make_unique<Shape>(*shape_ptr);
259+
return;
260+
}
261+
auto owning_shape_ptr = std::make_unique<Shape>(shape);
262+
if (owning_shape_ptr->IsArray() && !owning_shape_ptr->has_layout()) {
263+
*owning_shape_ptr->mutable_layout() =
264+
LayoutUtil::GetDefaultLayoutForShape(*owning_shape_ptr);
265+
}
266+
if (owning_shape_ptr->IsArray() &&
267+
LayoutUtil::HasCustomElementSizeInBits(*owning_shape_ptr)) {
268+
owning_shape_ptr->mutable_layout()->set_element_size_in_bits(0);
266269
}
270+
shape_ = std::move(owning_shape_ptr);
267271
}
268272

269273
void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,

xla/service/collective_pipeliner_test.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ ENTRY entry {
191191
EXPECT_EQ(get_tuple_index->tuple_index(), 3);
192192
}
193193

194-
TEST_F(CollectivePipelinerTest, MinimalCase) {
194+
TEST_F(CollectivePipelinerTest, MinimalCaseWithoutDefaultLayouts) {
195195
constexpr absl::string_view hlo_string = R"(
196196
HloModule module
197197
@@ -235,8 +235,10 @@ TEST_F(CollectivePipelinerTest, MinimalCase) {
235235
ROOT dst_data = bf16[3,8,128] get-tuple-element(while), index=1
236236
}
237237
)";
238-
TF_ASSERT_OK_AND_ASSIGN(auto module,
239-
ParseAndReturnUnverifiedModule(hlo_string, config_));
238+
HloParserOptions parser_config;
239+
parser_config.set_fill_missing_layouts(false);
240+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
241+
hlo_string, config_, parser_config));
240242
EXPECT_THAT(RunOptimizer(module.get(), /*last_run=*/true),
241243
IsOkAndHolds(true));
242244

xla/shape.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,8 @@ class Shape {
257257

258258
bool operator()(const Shape& lhs, const Shape& rhs);
259259

260-
Equal& IgnoreLayout() {
261-
ignore_layout_ = true;
260+
Equal& IgnoreLayout(bool ignore_layout = true) {
261+
ignore_layout_ = ignore_layout;
262262
return *this;
263263
}
264264
Equal& IgnoreTilesInLayout() {

0 commit comments

Comments
 (0)