@@ -46,6 +46,7 @@ limitations under the License.
4646#include " xla/hlo/ir/hlo_computation.h"
4747#include " xla/hlo/ir/hlo_instruction.h"
4848#include " xla/hlo/ir/hlo_opcode.h"
49+ #include " xla/hlo/parser/hlo_parser.h"
4950#include " xla/hlo/testlib/hlo_hardware_independent_test_base.h"
5051#include " xla/hlo/transforms/simplifiers/hlo_element_type_converter.h"
5152#include " xla/layout_util.h"
@@ -4566,6 +4567,71 @@ ENTRY main {
45664567 }
45674568}
45684569
4570+ TEST_P (HloEvaluatorBf16Test, BitcastWithoutLayout) {
4571+ const absl::string_view hlo_text_base = R"(
4572+ HloModule Bitcast
4573+
4574+ ENTRY main {
4575+ param = %s[2,4] parameter(0)
4576+ ROOT bitcast = %s[4,2,1] bitcast(%s[2,4] param)
4577+ }
4578+ )" ;
4579+ std::string hlo_text;
4580+ Literal arg;
4581+ if (use_bfloat16_) {
4582+ hlo_text = absl::StrFormat (hlo_text_base, " bf16" , " bf16" , " bf16" );
4583+ arg = LiteralUtil::CreateR2<bfloat16>(
4584+ {{bfloat16 (1 ), bfloat16 (2 ), bfloat16 (3 ), bfloat16 (4 )},
4585+ {bfloat16 (5 ), bfloat16 (6 ), bfloat16 (7 ), bfloat16 (8 )}});
4586+ } else {
4587+ hlo_text = absl::StrFormat (hlo_text_base, " f32" , " f32" , " f32" );
4588+ arg = LiteralUtil::CreateR2<float >({{1 ., 2 ., 3 ., 4 .}, {5 ., 6 ., 7 ., 8 .}});
4589+ }
4590+
4591+ HloParserOptions parser_config;
4592+ parser_config.set_fill_missing_layouts (false );
4593+ TF_ASSERT_OK_AND_ASSIGN (m_, ParseAndReturnUnverifiedModule (
4594+ hlo_text, HloModuleConfig (), parser_config));
4595+
4596+ absl::StatusOr<Literal> actual = Evaluate ({&arg});
4597+ EXPECT_FALSE (actual.ok ());
4598+ EXPECT_EQ (actual.status ().message (),
4599+ " Evaluator cannot evaluate bitcast for non-scalar operand without "
4600+ " assigned layout." );
4601+ }
4602+
4603+ TEST_P (HloEvaluatorBf16Test, EffectiveScalarBitcastWithoutLayout) {
4604+ const absl::string_view hlo_text_base = R"(
4605+ HloModule Bitcast
4606+
4607+ ENTRY main {
4608+ param = %s[1,1] parameter(0)
4609+ ROOT bitcast = %s[1,1,1] bitcast(%s[1,1] param)
4610+ }
4611+ )" ;
4612+ std::string hlo_text;
4613+ Literal arg;
4614+ if (use_bfloat16_) {
4615+ hlo_text = absl::StrFormat (hlo_text_base, " bf16" , " bf16" , " bf16" );
4616+ arg = LiteralUtil::CreateR2<bfloat16>({{bfloat16 (2 )}});
4617+ } else {
4618+ hlo_text = absl::StrFormat (hlo_text_base, " f32" , " f32" , " f32" );
4619+ arg = LiteralUtil::CreateR2<float >({{2 .}});
4620+ }
4621+
4622+ HloParserOptions parser_config;
4623+ parser_config.set_fill_missing_layouts (false );
4624+ TF_ASSERT_OK_AND_ASSIGN (m_, ParseAndReturnUnverifiedModule (
4625+ hlo_text, HloModuleConfig (), parser_config));
4626+
4627+ TF_ASSERT_OK_AND_ASSIGN (Literal actual, Evaluate ({&arg}));
4628+ if (use_bfloat16_) {
4629+ EXPECT_TRUE (absl::c_equal (arg.data <bfloat16>(), actual.data <bfloat16>()));
4630+ } else {
4631+ EXPECT_TRUE (absl::c_equal (arg.data <float >(), actual.data <float >()));
4632+ }
4633+ }
4634+
45694635// Check that s32 under/overflow doesn't trigger a ubsan failure.
45704636TEST_F (HloEvaluatorTest, Int32Overflow) {
45714637 const absl::string_view hlo_text = R"(
0 commit comments