Skip to content

Commit 7f7f997

Browse files
seherellisGoogle-ML-Automation
authored andcommitted
[XLA:SchedulingAnnotations] Add a configuration to filter the ops so that we can keep/drop the annotations in/from certain synchronous ops.
If an annotation gap is discovered, print the respective path between the annotated ops. This is particularly useful to detect when data-dependent sync & async ops were mistakenly annotated with the same scheduling group. PiperOrigin-RevId: 708378382
1 parent ff4f8b8 commit 7f7f997

File tree

5 files changed

+109
-22
lines changed

5 files changed

+109
-22
lines changed

xla/service/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6353,6 +6353,7 @@ cc_library(
63536353
hdrs = ["legalize_scheduling_annotations.h"],
63546354
deps = [
63556355
"//xla:side_effect_util",
6356+
"//xla:util",
63566357
"//xla:xla_data_proto_cc",
63576358
"//xla/hlo/ir:hlo",
63586359
"//xla/hlo/pass:hlo_pass",
@@ -6378,9 +6379,9 @@ xla_cc_test(
63786379
"//xla:util",
63796380
"//xla/hlo/ir:hlo",
63806381
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
6382+
"//xla/tsl/platform:statusor",
63816383
"@com_google_absl//absl/strings:string_view",
63826384
"@com_google_googletest//:gtest_main",
6383-
"@tsl//tsl/platform:statusor",
63846385
],
63856386
)
63866387

xla/service/latency_hiding_scheduler_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ absl::StatusOr<bool> RunScheduler(
152152
/*convert_collective_permute=*/HloPredicateTrue};
153153
TF_ASSIGN_OR_RETURN(bool value,
154154
AsyncCollectiveCreator(std::move(config)).Run(module));
155-
TF_ASSIGN_OR_RETURN(value, LegalizeSchedulingAnnotations().Run(module));
155+
TF_ASSIGN_OR_RETURN(value, LegalizeSchedulingAnnotations(
156+
LegalizeSchedulingAnnotations::Config())
157+
.Run(module));
156158
HloCostAnalysis::ShapeSizeFunction shape_size_bytes =
157159
[&shape_size_bytes](const Shape& shape) -> int64_t {
158160
int64_t shape_size = 0;

xla/service/legalize_scheduling_annotations.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,46 @@ absl::StatusOr<int64_t> ExtractAnnotation(
5555
return annotation_id;
5656
}
5757

58+
void DropSchedulingAnnotation(HloInstruction* instr) {
59+
VLOG(2) << "Dropping annotation from " << instr->name();
60+
FrontendAttributes frontend_attributes = instr->frontend_attributes();
61+
frontend_attributes.mutable_map()->erase("_scheduling_group_id");
62+
instr->set_frontend_attributes(frontend_attributes);
63+
}
64+
65+
bool IsSupportedAsyncOp(HloInstruction* instr) {
66+
return HloPredicateIsOp<
67+
HloOpcode::kAllGatherDone, HloOpcode::kAllGatherStart,
68+
HloOpcode::kAllReduceDone, HloOpcode::kAllReduceStart,
69+
HloOpcode::kCollectivePermuteDone, HloOpcode::kCollectivePermuteStart,
70+
HloOpcode::kAsyncDone, HloOpcode::kAsyncStart, HloOpcode::kSendDone,
71+
HloOpcode::kSend, HloOpcode::kRecvDone, HloOpcode::kRecv>(instr);
72+
}
73+
74+
bool LegalizeSchedulingAnnotations::KeepSchedulingAnnotation(
75+
HloInstruction* instr) {
76+
return IsSupportedAsyncOp(instr) || config_.keep_sync_annotation(instr);
77+
}
78+
5879
absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
5980
HloModule* module,
6081
const absl::flat_hash_set<absl::string_view>& execution_threads) {
6182
absl::flat_hash_map<HloInstruction*, int64_t> annotation;
6283
absl::flat_hash_map<int64_t, HloComputation*> annotation_to_computation;
6384
absl::flat_hash_map<int64_t, std::vector<HloInstruction*>>
6485
annotation_to_instructions;
86+
// Filter the annotated ops (using config) to keep the annotations only in the
87+
// desired sync ops. Annotations in all async ops are kept.
88+
for (HloComputation* computation : module->MakeNonfusionComputations()) {
89+
for (HloInstruction* instr : computation->instructions()) {
90+
if (!instr->frontend_attributes().map().contains(
91+
"_scheduling_group_id") ||
92+
KeepSchedulingAnnotation(instr)) {
93+
continue;
94+
}
95+
DropSchedulingAnnotation(instr);
96+
}
97+
}
6598
// Find the annotated instructions and save relevant information.
6699
for (HloComputation* computation :
67100
module->MakeNonfusionComputations(execution_threads)) {
@@ -94,6 +127,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
94127
// there are some fused instructions with different annotations.
95128
for (HloComputation* computation : module->computations(execution_threads)) {
96129
if (!computation->IsFusionComputation() ||
130+
!config_.keep_sync_annotation(computation->FusionInstruction()) ||
97131
annotation.contains(computation->FusionInstruction())) {
98132
continue;
99133
}
@@ -131,6 +165,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
131165
if (annotation_to_computation.empty()) {
132166
return false;
133167
}
168+
absl::flat_hash_map<HloInstruction*, HloInstruction*> parent;
134169
for (const auto& [id, annotated_instructions] : annotation_to_instructions) {
135170
// First find the frontier nodes that are not annotated with id but use an
136171
// annotated instruction with id.
@@ -152,6 +187,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
152187
if (!visited.contains(user) &&
153188
(!annotation.contains(user) || annotation[user] != id)) {
154189
stack.push_back(user);
190+
parent[user] = instr;
155191
visited.insert(user);
156192
VLOG(2) << "Annotation group: " << id
157193
<< ", frontier using a root: " << user->name();
@@ -168,6 +204,13 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
168204
stack.pop_back();
169205
for (HloInstruction* user : instr->users()) {
170206
if (annotation.contains(user) && annotation[user] == id) {
207+
LOG(INFO) << "PATH: " << user->name();
208+
HloInstruction* current = instr;
209+
LOG(INFO) << "PATH: " << current->name();
210+
while (parent.contains(current)) {
211+
current = parent[current];
212+
LOG(INFO) << "PATH: " << current->name();
213+
}
171214
return absl::UnimplementedError(
172215
absl::StrCat("Support for annotation groups with gaps doesn't "
173216
"exist yet, annotation: ",
@@ -179,6 +222,7 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
179222
continue;
180223
}
181224
stack.push_back(user);
225+
parent[user] = instr;
182226
visited.insert(user);
183227
}
184228
}

xla/service/legalize_scheduling_annotations.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,38 @@ limitations under the License.
1616
#ifndef XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_
1717
#define XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_
1818

19+
#include <utility>
20+
1921
#include "absl/container/flat_hash_set.h"
2022
#include "absl/status/statusor.h"
2123
#include "absl/strings/string_view.h"
2224
#include "xla/hlo/ir/hlo_module.h"
2325
#include "xla/hlo/pass/hlo_pass_interface.h"
26+
#include "xla/util.h"
2427

2528
namespace xla {
2629

2730
// Legalizer pass for scheduling annotations (to be used in
2831
// LatencyHidingScheduler).
2932
class LegalizeSchedulingAnnotations : public HloModulePass {
3033
public:
31-
LegalizeSchedulingAnnotations() = default;
34+
struct Config {
35+
HloPredicate keep_sync_annotation = HloPredicateTrue;
36+
};
37+
38+
explicit LegalizeSchedulingAnnotations(Config config)
39+
: config_(std::move(config)) {}
3240
absl::string_view name() const override {
3341
return "legalize-scheduling-annotations";
3442
}
3543
using HloPassInterface::Run;
3644
absl::StatusOr<bool> Run(
3745
HloModule* module,
3846
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
47+
48+
private:
49+
bool KeepSchedulingAnnotation(HloInstruction* instr);
50+
Config config_;
3951
};
4052
} // namespace xla
4153

xla/service/legalize_scheduling_annotations_test.cc

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ limitations under the License.
2020
#include <gtest/gtest.h>
2121
#include "absl/strings/string_view.h"
2222
#include "xla/hlo/ir/hlo_instruction.h"
23+
#include "xla/hlo/ir/hlo_opcode.h"
2324
#include "xla/hlo/ir/hlo_schedule.h"
2425
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
2526
#include "xla/side_effect_util.h"
2627
#include "xla/test_helpers.h"
28+
#include "xla/tsl/platform/statusor.h"
2729
#include "xla/util.h"
28-
#include "tsl/platform/statusor.h"
2930

3031
namespace xla {
3132
namespace {
@@ -47,9 +48,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, NonIntegerAnnotation) {
4748
)";
4849
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
4950
ParseAndReturnVerifiedModule(hlo_string));
50-
51+
LegalizeSchedulingAnnotations::Config config;
5152
EXPECT_IS_NOT_OK(
52-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
53+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
5354
}
5455

5556
TEST_F(LegalizeSchedulingAnnotationsTest, MultipleAnnotations) {
@@ -69,9 +70,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MultipleAnnotations) {
6970
)";
7071
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
7172
ParseAndReturnVerifiedModule(hlo_string));
72-
73+
LegalizeSchedulingAnnotations::Config config;
7374
EXPECT_IS_NOT_OK(
74-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
75+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
7576
}
7677

7778
TEST_F(LegalizeSchedulingAnnotationsTest, NegativeAnnotation) {
@@ -89,9 +90,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, NegativeAnnotation) {
8990
)";
9091
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
9192
ParseAndReturnVerifiedModule(hlo_string));
92-
93+
LegalizeSchedulingAnnotations::Config config;
9394
EXPECT_IS_NOT_OK(
94-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
95+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
9596
}
9697

9798
TEST_F(LegalizeSchedulingAnnotationsTest, CrossComputationAnnotation) {
@@ -129,9 +130,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, CrossComputationAnnotation) {
129130
)";
130131
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
131132
ParseAndReturnVerifiedModule(hlo_string));
132-
133+
LegalizeSchedulingAnnotations::Config config;
133134
EXPECT_IS_NOT_OK(
134-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
135+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
135136
}
136137

137138
TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps) {
@@ -153,9 +154,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps) {
153154
)";
154155
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
155156
ParseAndReturnVerifiedModule(hlo_string));
156-
157+
LegalizeSchedulingAnnotations::Config config;
157158
EXPECT_IS_NOT_OK(
158-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
159+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
159160
}
160161

161162
TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps2) {
@@ -177,9 +178,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps2) {
177178
)";
178179
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
179180
ParseAndReturnVerifiedModule(hlo_string));
180-
181+
LegalizeSchedulingAnnotations::Config config;
181182
EXPECT_IS_NOT_OK(
182-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
183+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
183184
}
184185

185186
TEST_F(LegalizeSchedulingAnnotationsTest, MissingAnnotationInStart) {
@@ -197,9 +198,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MissingAnnotationInStart) {
197198
)";
198199
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
199200
ParseAndReturnVerifiedModule(hlo_string));
200-
201+
LegalizeSchedulingAnnotations::Config config;
201202
EXPECT_IS_NOT_OK(
202-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
203+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
203204
}
204205

205206
TEST_F(LegalizeSchedulingAnnotationsTest, MoveFusedOpAnnotationToCaller) {
@@ -220,8 +221,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MoveFusedOpAnnotationToCaller) {
220221
)";
221222
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
222223
ParseAndReturnVerifiedModule(hlo_string));
223-
224-
EXPECT_IS_OK(LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
224+
LegalizeSchedulingAnnotations::Config config;
225+
EXPECT_IS_OK(
226+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
225227

226228
HloInstruction* fusion = hlo_module->entry_computation()->root_instruction();
227229
const auto& attrs = fusion->frontend_attributes().map();
@@ -248,9 +250,35 @@ TEST_F(LegalizeSchedulingAnnotationsTest, FusedOpsWithDifferentAnnotationIds) {
248250
)";
249251
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
250252
ParseAndReturnVerifiedModule(hlo_string));
251-
253+
LegalizeSchedulingAnnotations::Config config;
252254
EXPECT_IS_NOT_OK(
253-
LegalizeSchedulingAnnotations().Run(hlo_module.get()).status());
255+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
256+
}
257+
258+
TEST_F(LegalizeSchedulingAnnotationsTest, DropAnnotationFromBitcast) {
259+
constexpr absl::string_view hlo_string = R"(
260+
HloModule test
261+
ENTRY entry {
262+
p0 = f32[256,1024]{1,0} parameter(0)
263+
p1 = f32[16,64,256]{2,1,0} parameter(1)
264+
ags0 = (f32[256,1024]{1,0}, f32[1024,1024]{1,0}) all-gather-start(p0), replica_groups={{0,1,2,3}}, dimensions={0}, frontend_attributes={_scheduling_group_id="0"}
265+
bitcast = f32[16,64,256]{2,1,0} bitcast(p1), frontend_attributes={_scheduling_group_id="0"}
266+
agd0 = f32[1024,1024]{1,0} all-gather-done(ags0), frontend_attributes={_scheduling_group_id="0"}
267+
ROOT tuple = (f32[16,64,256]{2,1,0}, f32[1024,1024]{1,0}) tuple(bitcast, agd0)
268+
}
269+
)";
270+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
271+
ParseAndReturnVerifiedModule(hlo_string));
272+
LegalizeSchedulingAnnotations::Config config;
273+
config.keep_sync_annotation = [](const HloInstruction* instr) {
274+
return instr->opcode() != HloOpcode::kBitcast;
275+
};
276+
EXPECT_IS_OK(
277+
LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status());
278+
HloInstruction* bitcast =
279+
hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
280+
EXPECT_FALSE(
281+
bitcast->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr));
254282
}
255283

256284
} // namespace

0 commit comments

Comments
 (0)