Skip to content

Commit e9947dd

Browse files
[XLA:GPU] Do not compute suggested combiner threshold if there are no pipelined collectives in IR.
PiperOrigin-RevId: 702266560
1 parent f0ca2e2 commit e9947dd

10 files changed

+387
-2
lines changed

xla/service/gpu/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3162,9 +3162,9 @@ xla_cc_test(
31623162
deps = [
31633163
":gpu_all_gather_combiner",
31643164
"//xla/hlo/ir:hlo",
3165+
"//xla/hlo/testlib:filecheck",
31653166
"//xla/service:collective_utils",
31663167
"//xla/stream_executor:device_description",
3167-
"//xla/tests:filecheck",
31683168
"//xla/tests:hlo_test_base",
31693169
"@com_google_absl//absl/log",
31703170
"@com_google_absl//absl/strings:string_view",

xla/service/gpu/all_gather_combiner.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ absl::StatusOr<bool> GpuAllGatherCombiner::Run(
6868
return AllGatherCombiner::Run(module, execution_threads);
6969
}
7070

71+
// If there are no pipelined instructions in the IR, the optimizations below
72+
// do not kick in anyway.
73+
// Exit early so we do not perform expensive scheduling dry run below.
74+
if (!ContainsPipelinedInstruction(*module)) {
75+
return AllGatherCombiner::Run(module, execution_threads);
76+
}
77+
7178
// Combine as much as possible for pipelined collectives.
7279
int previous_combiner_threshold = combine_threshold_in_bytes_;
7380
combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(

xla/service/gpu/all_gather_combiner_test.cc

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ limitations under the License.
1919
#include "absl/log/log.h"
2020
#include "absl/strings/string_view.h"
2121
#include "xla/hlo/ir/hlo_instruction.h"
22+
#include "xla/hlo/testlib/filecheck.h"
2223
#include "xla/service/collective_utils.h"
2324
#include "xla/stream_executor/device_description.h"
24-
#include "xla/tests/filecheck.h"
2525
#include "xla/tests/hlo_test_base.h"
2626
#include "tsl/platform/statusor.h"
2727
#include "tsl/platform/test.h"
@@ -137,6 +137,106 @@ ENTRY entry {
137137
kExpected));
138138
}
139139

140+
TEST_F(GpuAllGatherCombinerTest,
141+
CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
142+
// The IR is the minimal valid example of a while loop with AG inside.
143+
// All collectives are not pipelined.
144+
constexpr absl::string_view kHloString = R"(
145+
HloModule module
146+
147+
add {
148+
lhs = bf16[] parameter(0)
149+
rhs = bf16[] parameter(1)
150+
ROOT add = bf16[] add(lhs, rhs)
151+
}
152+
153+
while_cond {
154+
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
155+
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
156+
gte = s32[] get-tuple-element(param), index=0
157+
constant.1 = s32[] constant(8)
158+
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
159+
}
160+
161+
while_body {
162+
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
163+
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
164+
param.0 = s32[] get-tuple-element(param), index=0
165+
param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
166+
param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
167+
param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
168+
param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
169+
param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
170+
param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
171+
param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
172+
zero = bf16[] constant(0)
173+
one = s32[] constant(1)
174+
it = s32[] add(param.0, one)
175+
ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), dimensions={0}
176+
ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), dimensions={0}
177+
ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), dimensions={0}
178+
ag.nonpipelined.3 = bf16[6,8,128] all-gather(param.nonpipelined.3),
179+
dimensions={0}
180+
ag.nonpipelined.4 = bf16[6,8,128] all-gather(param.nonpipelined.4),
181+
dimensions={0}
182+
ag.nonpipelined.6 = bf16[6,8,128] all-gather(param.nonpipelined.5),
183+
dimensions={0}
184+
ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, ag.nonpipelined.3, ag.nonpipelined.4, ag.nonpipelined.6, param.7)
185+
}
186+
187+
ENTRY entry {
188+
c0 = s32[] constant(0)
189+
p0 = bf16[6,8,128] parameter(0)
190+
p1 = bf16[3,1,2,128] parameter(1)
191+
tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
192+
while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
193+
ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
194+
}
195+
)";
196+
auto config =
197+
GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);
198+
199+
DeviceDescription device_info;
200+
// Combine at most 2 collectives.
201+
int collective_size = 2 * 6 * 8 * 128;
202+
int threshold_bytes = 2 * collective_size;
203+
int current_peak_mem = 90604;
204+
int pointer_size = 4;
205+
device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4);
206+
207+
TF_ASSERT_OK_AND_ASSIGN(auto module,
208+
ParseAndReturnVerifiedModule(kHloString, config));
209+
TF_ASSERT_OK_AND_ASSIGN(
210+
bool changed,
211+
GpuAllGatherCombiner(
212+
device_info, /*default_combine_threshold_in_bytes=*/
213+
kDefaultAllGatherCombineThreshold,
214+
/*combine_threshold_in_bytes=*/kDefaultAllGatherCombineThreshold,
215+
/*combine_threshold_count=*/256,
216+
/*combine_by_dim=*/false,
217+
/*combine_different_dtypes=*/true, pointer_size)
218+
.Run(module.get()));
219+
220+
VLOG(1) << module->ToString();
221+
EXPECT_TRUE(changed);
222+
const absl::string_view kExpected = R"(
223+
// CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
224+
// CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
225+
// CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
226+
// CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
227+
// CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
228+
// CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
229+
// CHECK: all-gather(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
230+
// CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
231+
)";
232+
EXPECT_TRUE(*RunFileCheck(
233+
module->ToString(HloPrintOptions()
234+
.set_print_operand_shape(false)
235+
.set_print_result_shape(false)
236+
.set_print_operand_index_annotation_interval(10)),
237+
kExpected));
238+
}
239+
140240
TEST_F(GpuAllGatherCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
141241
// The IR is the minimal valid example of a while loop with AG inside. Three
142242
// are annotated as pipelined and three are not. Various configurations of the

xla/service/gpu/all_reduce_combiner.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ absl::StatusOr<bool> GpuAllReduceCombiner::Run(
6666
return AllReduceCombiner::Run(module, execution_threads);
6767
}
6868

69+
// If there are no pipelined instructions in the IR, the optimizations below
70+
// do not kick in anyway.
71+
// Exit early so we do not perform expensive scheduling dry run below.
72+
if (!ContainsPipelinedInstruction(*module)) {
73+
return AllReduceCombiner::Run(module, execution_threads);
74+
}
75+
6976
// Combine as much as possible for pipelined collectives.
7077
int previous_combiner_threshold = combine_threshold_in_bytes_;
7178
combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(

xla/service/gpu/all_reduce_combiner_test.cc

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,101 @@ ENTRY entry {
137137
kExpected));
138138
}
139139

140+
TEST_F(GpuAllReduceCombinerTest,
141+
CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
142+
// The IR is the minimal valid example of a while loop with AR inside.
143+
// All collectives are not pipelined.
144+
constexpr absl::string_view kHloString = R"(
145+
HloModule module
146+
147+
add {
148+
lhs = bf16[] parameter(0)
149+
rhs = bf16[] parameter(1)
150+
ROOT add = bf16[] add(lhs, rhs)
151+
}
152+
153+
while_cond {
154+
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
155+
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
156+
gte = s32[] get-tuple-element(param), index=0
157+
constant.1 = s32[] constant(8)
158+
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
159+
}
160+
161+
while_body {
162+
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
163+
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
164+
param.0 = s32[] get-tuple-element(param), index=0
165+
param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
166+
param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
167+
param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
168+
param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
169+
param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
170+
param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
171+
param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
172+
zero = bf16[] constant(0)
173+
one = s32[] constant(1)
174+
it = s32[] add(param.0, one)
175+
ar.nonpipelined.0 = bf16[6,8,128] all-reduce(param.nonpipelined.0),
176+
to_apply=add
177+
ar.nonpipelined.1 = bf16[6,8,128] all-reduce(param.nonpipelined.1),
178+
to_apply=add
179+
ar.nonpipelined.2 = bf16[6,8,128] all-reduce(param.nonpipelined.2),
180+
to_apply=add
181+
ar.nonpipelined.3 = bf16[6,8,128] all-reduce(param.nonpipelined.3),
182+
to_apply=add
183+
ar.nonpipelined.4 = bf16[6,8,128] all-reduce(param.nonpipelined.4),
184+
to_apply=add
185+
ar.nonpipelined.5 = bf16[6,8,128] all-reduce(param.nonpipelined.5),
186+
to_apply=add
187+
ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ar.nonpipelined.0, ar.nonpipelined.1, ar.nonpipelined.2, ar.nonpipelined.3, ar.nonpipelined.4, ar.nonpipelined.5, param.7)
188+
}
189+
190+
ENTRY entry {
191+
c0 = s32[] constant(0)
192+
p0 = bf16[6,8,128] parameter(0)
193+
p1 = bf16[3,1,2,128] parameter(1)
194+
tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
195+
while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
196+
ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
197+
}
198+
)";
199+
auto config =
200+
GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);
201+
DeviceDescription device_info;
202+
int pointer_size = 4;
203+
204+
TF_ASSERT_OK_AND_ASSIGN(auto module,
205+
ParseAndReturnVerifiedModule(kHloString, config));
206+
TF_ASSERT_OK_AND_ASSIGN(
207+
bool changed,
208+
GpuAllReduceCombiner(
209+
device_info, /*default_combine_threshold_in_bytes=*/
210+
kDefaultAllReduceCombineThreshold,
211+
/*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold,
212+
/*combine_threshold_count=*/256, pointer_size)
213+
.Run(module.get()));
214+
215+
VLOG(1) << module->ToString();
216+
EXPECT_TRUE(changed);
217+
const absl::string_view kExpected = R"(
218+
// CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
219+
// CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
220+
// CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
221+
// CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
222+
// CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
223+
// CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
224+
// CHECK: all-reduce(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
225+
// CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
226+
)";
227+
EXPECT_TRUE(*RunFileCheck(
228+
module->ToString(HloPrintOptions()
229+
.set_print_operand_shape(false)
230+
.set_print_result_shape(false)
231+
.set_print_operand_index_annotation_interval(10)),
232+
kExpected));
233+
}
234+
140235
TEST_F(GpuAllReduceCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
141236
// The IR is the minimal valid example of a while loop with AR inside. Three
142237
// are annotated as pipelined and three are not. Various configurations of the

xla/service/gpu/gpu_collective_combiner_utils.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,20 @@ absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
8080
return instr->set_backend_config(config);
8181
}
8282

83+
bool ContainsPipelinedInstruction(const HloModule& module) {
84+
for (const HloComputation* computation : module.computations()) {
85+
for (const HloInstruction* instr : computation->instructions()) {
86+
auto backend_config = instr->backend_config<GpuBackendConfig>();
87+
if (!backend_config.ok()) {
88+
VLOG(2) << "Cannot read backend config for: " << instr->ToString();
89+
continue;
90+
}
91+
if (backend_config->collective_backend_config().is_pipelined()) {
92+
return true;
93+
}
94+
}
95+
}
96+
return false;
97+
}
98+
8399
} // namespace xla::gpu

xla/service/gpu/gpu_collective_combiner_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ int64_t ComputeSuggestedCombinerThreshold(
4646
// this.
4747
absl::Status AppendPipelinedInstruction(HloInstruction* instr);
4848

49+
// Returns true if module contains any pipelined instruction. False otherwise.
50+
bool ContainsPipelinedInstruction(const HloModule& module);
51+
4952
} // namespace xla::gpu
5053

5154
#endif // XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_

xla/service/gpu/gpu_collective_combiner_utils_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,5 +504,59 @@ TEST_F(CollectiveCombinerUtilsTest,
504504
});
505505
}
506506

507+
TEST_F(CollectiveCombinerUtilsTest,
508+
ContainsPipelinedInstructionReturnsTrueForPipelinedInstructions) {
509+
// The IR is a minimal module with a single all-reduce whose backend config
510+
// marks it as pipelined; ContainsPipelinedInstruction is expected to
511+
// return true for it.
512+
constexpr absl::string_view kHloText = R"(
513+
HloModule module
514+
515+
add {
516+
lhs = bf16[] parameter(0)
517+
rhs = bf16[] parameter(1)
518+
ROOT add = bf16[] add(lhs, rhs)
519+
}
520+
521+
ENTRY entry {
522+
p0 = bf16[1] parameter(0)
523+
ROOT ar.pipelined.1 = bf16[1] all-reduce(p0),
524+
to_apply=add,
525+
backend_config={"collective_backend_config": {"is_pipelined": true}}
526+
}
527+
)";
528+
529+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
530+
EXPECT_TRUE(ContainsPipelinedInstruction(*module));
531+
}
532+
533+
TEST_F(CollectiveCombinerUtilsTest,
534+
ContainsPipelinedInstructionReturnsFalseForNonPipelinedInstructions) {
535+
// The IR is a minimal module with two all-reduces, neither marked as
536+
// pipelined (one has no backend config, one has is_pipelined=false);
537+
// ContainsPipelinedInstruction is expected to return false.
538+
constexpr absl::string_view kHloText = R"(
539+
HloModule module
540+
541+
add {
542+
lhs = bf16[] parameter(0)
543+
rhs = bf16[] parameter(1)
544+
ROOT add = bf16[] add(lhs, rhs)
545+
}
546+
547+
ENTRY entry {
548+
p0 = bf16[1] parameter(0)
549+
ar.0 = bf16[1] all-reduce(p0),
550+
to_apply=add
551+
ROOT ar.1 = bf16[1] all-reduce(ar.0),
552+
to_apply=add,
553+
backend_config={"collective_backend_config": {"is_pipelined": false}}
554+
}
555+
)";
556+
557+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
558+
EXPECT_FALSE(ContainsPipelinedInstruction(*module));
559+
}
560+
507561
} // namespace
508562
} // namespace xla::gpu

xla/service/gpu/reduce_scatter_combiner.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ absl::StatusOr<bool> GpuReduceScatterCombiner::Run(
6666
return ReduceScatterCombiner::Run(module, execution_threads);
6767
}
6868

69+
// If there are no pipelined instructions in the IR, the optimizations below
70+
// do not kick in anyway.
71+
// Exit early so we do not perform expensive scheduling dry run below.
72+
if (!ContainsPipelinedInstruction(*module)) {
73+
return ReduceScatterCombiner::Run(module, execution_threads);
74+
}
75+
6976
// Combine as much as possible for pipelined collectives.
7077
int previous_combiner_threshold = combine_threshold_in_bytes_;
7178
combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(

0 commit comments

Comments
 (0)