@@ -19,9 +19,9 @@ limitations under the License.
1919#include " absl/log/log.h"
2020#include " absl/strings/string_view.h"
2121#include " xla/hlo/ir/hlo_instruction.h"
22+ #include " xla/hlo/testlib/filecheck.h"
2223#include " xla/service/collective_utils.h"
2324#include " xla/stream_executor/device_description.h"
24- #include " xla/tests/filecheck.h"
2525#include " xla/tests/hlo_test_base.h"
2626#include " tsl/platform/statusor.h"
2727#include " tsl/platform/test.h"
@@ -137,6 +137,106 @@ ENTRY entry {
137137 kExpected ));
138138}
139139
140+ TEST_F (GpuAllGatherCombinerTest,
141+ CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
142+ // The IR is the minimal valid example of a while loop with AG inside.
143+ // All collectives are not pipelined.
144+ constexpr absl::string_view kHloString = R"(
145+ HloModule module
146+
147+ add {
148+ lhs = bf16[] parameter(0)
149+ rhs = bf16[] parameter(1)
150+ ROOT add = bf16[] add(lhs, rhs)
151+ }
152+
153+ while_cond {
154+ param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
155+ bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
156+ gte = s32[] get-tuple-element(param), index=0
157+ constant.1 = s32[] constant(8)
158+ ROOT cmp = pred[] compare(gte, constant.1), direction=LT
159+ }
160+
161+ while_body {
162+ param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
163+ bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
164+ param.0 = s32[] get-tuple-element(param), index=0
165+ param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
166+ param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
167+ param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
168+ param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
169+ param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
170+ param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
171+ param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
172+ zero = bf16[] constant(0)
173+ one = s32[] constant(1)
174+ it = s32[] add(param.0, one)
175+ ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), dimensions={0}
176+ ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), dimensions={0}
177+ ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), dimensions={0}
178+ ag.nonpipelined.3 = bf16[6,8,128] all-gather(param.nonpipelined.3),
179+ dimensions={0}
180+ ag.nonpipelined.4 = bf16[6,8,128] all-gather(param.nonpipelined.4),
181+ dimensions={0}
182+ ag.nonpipelined.6 = bf16[6,8,128] all-gather(param.nonpipelined.5),
183+ dimensions={0}
184+ ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, ag.nonpipelined.3, ag.nonpipelined.4, ag.nonpipelined.6, param.7)
185+ }
186+
187+ ENTRY entry {
188+ c0 = s32[] constant(0)
189+ p0 = bf16[6,8,128] parameter(0)
190+ p1 = bf16[3,1,2,128] parameter(1)
191+ tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
192+ while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
193+ ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
194+ }
195+ )" ;
196+ auto config =
197+ GetModuleConfigForTest (/* replica_count=*/ 1 , /* num_partitions=*/ 2 );
198+
199+ DeviceDescription device_info;
200+ // Combine at most 2 collectives.
201+ int collective_size = 2 * 6 * 8 * 128 ;
202+ int threshold_bytes = 2 * collective_size;
203+ int current_peak_mem = 90604 ;
204+ int pointer_size = 4 ;
205+ device_info.set_device_memory_size (current_peak_mem + threshold_bytes * 4 );
206+
207+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
208+ ParseAndReturnVerifiedModule (kHloString , config));
209+ TF_ASSERT_OK_AND_ASSIGN (
210+ bool changed,
211+ GpuAllGatherCombiner (
212+ device_info, /* default_combine_threshold_in_bytes=*/
213+ kDefaultAllGatherCombineThreshold ,
214+ /* combine_threshold_in_bytes=*/ kDefaultAllGatherCombineThreshold ,
215+ /* combine_threshold_count=*/ 256 ,
216+ /* combine_by_dim=*/ false ,
217+ /* combine_different_dtypes=*/ true , pointer_size)
218+ .Run (module .get ()));
219+
220+ VLOG (1 ) << module ->ToString ();
221+ EXPECT_TRUE (changed);
222+ const absl::string_view kExpected = R"(
223+ // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
224+ // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
225+ // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
226+ // CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
227+ // CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
228+ // CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
229+ // CHECK: all-gather(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
230+ // CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
231+ )" ;
232+ EXPECT_TRUE (*RunFileCheck (
233+ module ->ToString (HloPrintOptions ()
234+ .set_print_operand_shape (false )
235+ .set_print_result_shape (false )
236+ .set_print_operand_index_annotation_interval (10 )),
237+ kExpected ));
238+ }
239+
140240TEST_F (GpuAllGatherCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
141241 // The IR is the minimal valid example of a while loop with AG inside. Three
142242 // are annotated as pipelined and three are not. Various configurations of the
0 commit comments