@@ -20,12 +20,13 @@ limitations under the License.
2020#include < gtest/gtest.h>
2121#include " absl/strings/string_view.h"
2222#include " xla/hlo/ir/hlo_instruction.h"
23+ #include " xla/hlo/ir/hlo_opcode.h"
2324#include " xla/hlo/ir/hlo_schedule.h"
2425#include " xla/hlo/testlib/hlo_hardware_independent_test_base.h"
2526#include " xla/side_effect_util.h"
2627#include " xla/test_helpers.h"
28+ #include " xla/tsl/platform/statusor.h"
2729#include " xla/util.h"
28- #include " tsl/platform/statusor.h"
2930
3031namespace xla {
3132namespace {
@@ -47,9 +48,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, NonIntegerAnnotation) {
4748 )" ;
4849 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
4950 ParseAndReturnVerifiedModule (hlo_string));
50-
51+ LegalizeSchedulingAnnotations::Config config;
5152 EXPECT_IS_NOT_OK (
52- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
53+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
5354}
5455
5556TEST_F (LegalizeSchedulingAnnotationsTest, MultipleAnnotations) {
@@ -69,9 +70,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MultipleAnnotations) {
6970 )" ;
7071 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
7172 ParseAndReturnVerifiedModule (hlo_string));
72-
73+ LegalizeSchedulingAnnotations::Config config;
7374 EXPECT_IS_NOT_OK (
74- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
75+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
7576}
7677
7778TEST_F (LegalizeSchedulingAnnotationsTest, NegativeAnnotation) {
@@ -89,9 +90,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, NegativeAnnotation) {
8990 )" ;
9091 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
9192 ParseAndReturnVerifiedModule (hlo_string));
92-
93+ LegalizeSchedulingAnnotations::Config config;
9394 EXPECT_IS_NOT_OK (
94- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
95+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
9596}
9697
9798TEST_F (LegalizeSchedulingAnnotationsTest, CrossComputationAnnotation) {
@@ -129,9 +130,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, CrossComputationAnnotation) {
129130)" ;
130131 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
131132 ParseAndReturnVerifiedModule (hlo_string));
132-
133+ LegalizeSchedulingAnnotations::Config config;
133134 EXPECT_IS_NOT_OK (
134- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
135+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
135136}
136137
137138TEST_F (LegalizeSchedulingAnnotationsTest, AnnotationWithGaps) {
@@ -153,9 +154,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps) {
153154)" ;
154155 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
155156 ParseAndReturnVerifiedModule (hlo_string));
156-
157+ LegalizeSchedulingAnnotations::Config config;
157158 EXPECT_IS_NOT_OK (
158- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
159+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
159160}
160161
161162TEST_F (LegalizeSchedulingAnnotationsTest, AnnotationWithGaps2) {
@@ -177,9 +178,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps2) {
177178)" ;
178179 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
179180 ParseAndReturnVerifiedModule (hlo_string));
180-
181+ LegalizeSchedulingAnnotations::Config config;
181182 EXPECT_IS_NOT_OK (
182- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
183+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
183184}
184185
185186TEST_F (LegalizeSchedulingAnnotationsTest, MissingAnnotationInStart) {
@@ -197,9 +198,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MissingAnnotationInStart) {
197198 )" ;
198199 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
199200 ParseAndReturnVerifiedModule (hlo_string));
200-
201+ LegalizeSchedulingAnnotations::Config config;
201202 EXPECT_IS_NOT_OK (
202- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
203+ LegalizeSchedulingAnnotations (config ).Run (hlo_module.get ()).status ());
203204}
204205
205206TEST_F (LegalizeSchedulingAnnotationsTest, MoveFusedOpAnnotationToCaller) {
@@ -220,8 +221,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MoveFusedOpAnnotationToCaller) {
220221 )" ;
221222 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
222223 ParseAndReturnVerifiedModule (hlo_string));
223-
224- EXPECT_IS_OK (LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
224+ LegalizeSchedulingAnnotations::Config config;
225+ EXPECT_IS_OK (
226+ LegalizeSchedulingAnnotations (config).Run (hlo_module.get ()).status ());
225227
226228 HloInstruction* fusion = hlo_module->entry_computation ()->root_instruction ();
227229 const auto & attrs = fusion->frontend_attributes ().map ();
@@ -248,9 +250,35 @@ TEST_F(LegalizeSchedulingAnnotationsTest, FusedOpsWithDifferentAnnotationIds) {
248250 )" ;
249251 TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
250252 ParseAndReturnVerifiedModule (hlo_string));
251-
253+ LegalizeSchedulingAnnotations::Config config;
252254 EXPECT_IS_NOT_OK (
253- LegalizeSchedulingAnnotations ().Run (hlo_module.get ()).status ());
255+ LegalizeSchedulingAnnotations (config).Run (hlo_module.get ()).status ());
256+ }
257+
258+ TEST_F (LegalizeSchedulingAnnotationsTest, DropAnnotationFromBitcast) {
259+ constexpr absl::string_view hlo_string = R"(
260+ HloModule test
261+ ENTRY entry {
262+ p0 = f32[256,1024]{1,0} parameter(0)
263+ p1 = f32[16,64,256]{2,1,0} parameter(1)
264+ ags0 = (f32[256,1024]{1,0}, f32[1024,1024]{1,0}) all-gather-start(p0), replica_groups={{0,1,2,3}}, dimensions={0}, frontend_attributes={_scheduling_group_id="0"}
265+ bitcast = f32[16,64,256]{2,1,0} bitcast(p1), frontend_attributes={_scheduling_group_id="0"}
266+ agd0 = f32[1024,1024]{1,0} all-gather-done(ags0), frontend_attributes={_scheduling_group_id="0"}
267+ ROOT tuple = (f32[16,64,256]{2,1,0}, f32[1024,1024]{1,0}) tuple(bitcast, agd0)
268+ }
269+ )" ;
270+ TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> hlo_module,
271+ ParseAndReturnVerifiedModule (hlo_string));
272+ LegalizeSchedulingAnnotations::Config config;
273+ config.keep_sync_annotation = [](const HloInstruction* instr) {
274+ return instr->opcode () != HloOpcode::kBitcast ;
275+ };
276+ EXPECT_IS_OK (
277+ LegalizeSchedulingAnnotations (config).Run (hlo_module.get ()).status ());
278+ HloInstruction* bitcast =
279+ hlo_module->entry_computation ()->root_instruction ()->mutable_operand (0 );
280+ EXPECT_FALSE (
281+ bitcast->frontend_attributes ().map ().contains (kXlaSchedulingGroupIdAttr ));
254282}
255283
256284} // namespace
0 commit comments