Skip to content

Commit 647c5cb

Browse files
akuegel authored and Google-ML-Automation committed
[XLA:GPU] Don't allow fusing DUS ops that share operands.
DynamicUpdateSlice is an in-place operation, so two such ops cannot be fused together if they share an operand. Also adjust the check for whether the in-place emitter can be used, so that it handles this case. PiperOrigin-RevId: 702288316
1 parent e9947dd commit 647c5cb

File tree

4 files changed

+104
-13
lines changed

4 files changed

+104
-13
lines changed

xla/service/gpu/ir_emission_utils.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ absl::StatusOr<bool> CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
347347
user_start_indices != dus_start_indices) {
348348
return false;
349349
}
350+
} else if (user != dus &&
351+
user.opcode() == HloOpcode::kDynamicUpdateSlice) {
352+
return false;
350353
} else if (user != dus && !user.instruction().IsElementwise() &&
351354
user.opcode() != HloOpcode::kBitcast &&
352355
user.opcode() != HloOpcode::kTuple) {

xla/service/gpu/ir_emission_utils_test.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,49 @@ ENTRY main {
870870
IsOkAndHolds(true));
871871
}
872872

873+
TEST_F(
874+
IrEmissionUtilsTest,
875+
CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesMultiOutputFusionSharedParameter) { // NOLINT
876+
const char* hlo = R"(
877+
HloModule MultipleInplaceDus, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) }
878+
879+
fused_computation {
880+
p0 = bf16[10,11,12] parameter(0)
881+
p1 = bf16[1,11,12] parameter(1)
882+
p2 = bf16[1,11,12] parameter(2)
883+
p3 = s32[] parameter(3)
884+
c0 = s32[] constant(0)
885+
cmp = pred[] compare(p3, c0), direction=EQ
886+
broadcast = pred[1,11,12] broadcast(cmp), dimensions={}
887+
select = bf16[1,11,12] select(broadcast, p1, p2)
888+
dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0)
889+
dus1 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0)
890+
ROOT tuple = (bf16[10,11,12], bf16[10,11,12]) tuple(dus0, dus1)
891+
}
892+
893+
ENTRY main {
894+
p0 = bf16[10,11,12] parameter(0)
895+
p1 = bf16[1,11,12] parameter(1)
896+
p2 = bf16[1,11,12] parameter(2)
897+
p3 = s32[] parameter(3)
898+
ROOT fusion_root_multiple = (bf16[10,11,12], bf16[10,11,12]) fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
899+
}
900+
)";
901+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
902+
ParseAndReturnVerifiedModule(hlo));
903+
auto fusion = module->entry_computation()->root_instruction();
904+
BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0);
905+
BufferAllocation::Slice slice0(&alloc, 0, 10);
906+
auto adaptor = HloFusionAdaptor::ForInstruction(fusion);
907+
EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
908+
*adaptor,
909+
[&slice0](const HloInstruction*, const ShapeIndex&) {
910+
return slice0;
911+
},
912+
fusion),
913+
IsOkAndHolds(false));
914+
}
915+
873916
TEST_F(
874917
IrEmissionUtilsTest,
875918
CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesMultiOutputFusionWithTransposeBitcasts) { // NOLINT

xla/service/gpu/transforms/horizontal_loop_fusion.cc

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ bool IsConcatenationInputFusion(const HloInstruction& instr) {
155155
instr.fused_expression_root()->opcode() == HloOpcode::kConcatenate;
156156
}
157157

158+
bool IsDynamicUpdateSliceFusion(const HloInstruction* instr) {
159+
if (instr->opcode() != HloOpcode::kFusion) {
160+
return false;
161+
}
162+
auto root = instr->fused_expression_root();
163+
if (root->opcode() == HloOpcode::kTuple) {
164+
return absl::c_any_of(root->operands(), [&](const HloInstruction* operand) {
165+
return operand->opcode() == HloOpcode::kDynamicUpdateSlice;
166+
});
167+
}
168+
return root->opcode() == HloOpcode::kDynamicUpdateSlice;
169+
}
170+
158171
bool IsFusibleCandidate(const HloInstruction& instr,
159172
const se::DeviceDescription& device_description) {
160173
// For now, we do not support fusing instruction with control flow.
@@ -247,16 +260,15 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr,
247260
return true;
248261
}
249262

250-
// Returns whether any operand of `instr` is a parameter instruction that
251-
// is shared with `fusion_instrs`.
252-
bool AnyOpndIsParamSharedAmongFusions(
263+
// Returns whether any operand of `instr` is an instruction that is shared with
264+
// `fusion_instrs`.
265+
bool AnyOperandIsSharedAmongFusions(
253266
const HloInstruction* instr,
254267
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
255268
return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
256-
return opnd->opcode() == HloOpcode::kParameter &&
257-
absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
258-
return user != instr && fusion_instrs.contains(user);
259-
});
269+
return absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
270+
return user != instr && fusion_instrs.contains(user);
271+
});
260272
});
261273
}
262274

@@ -298,11 +310,11 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
298310
<< " rejects may-not-be profitable fusion instr"
299311
<< instr->ToString();
300312
continue;
301-
} else if (sliced_input_fusion_ &&
302-
AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
303-
// Don't fuse fusions whose operands are parameter instructions that are
304-
// shared among fusions because we cannot i/o alias the produced
305-
// horizontal fusion due to the concat insertion.
313+
} else if ((sliced_input_fusion_ || IsDynamicUpdateSliceFusion(instr)) &&
314+
AnyOperandIsSharedAmongFusions(instr, fusible_candidates)) {
315+
// Don't fuse fusions with at least one shared operand because we cannot
316+
// i/o alias the produced horizontal fusion due to the concat insertion
317+
// (or run into aliasing problems with DynamicUpdateSlice fusions).
306318
VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
307319
<< " rejects the fusion instr because it shares parameter with"
308320
<< " other fusion candidates, instr: " << instr->ToString();

xla/service/gpu/transforms/horizontal_loop_fusion_test.cc

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
531531

532532
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
533533
auto module = ParseAndReturnVerifiedModule(R"(
534-
HloModule NegativeTestForDynamicUpdateSlice
534+
HloModule DynamicUpdateSlice
535535
536536
fusion.1 {
537537
p.0 = f16[5,9,10]{2,1,0} parameter(0)
@@ -576,6 +576,39 @@ TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
576576
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
577577
}
578578

579+
TEST_F(HorizontalLoopFusionTest, DontFuseDynamicUpdateSlice) {
580+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
581+
HloModule DynamicUpdateSliceFusionsShareParameter
582+
583+
fused_dynamic_update_slice {
584+
p0 = s32[3,3]{1,0} parameter(0)
585+
p1 = pred[3,2]{1,0} parameter(1)
586+
convert = s32[3,2]{1,0} convert(p1)
587+
zero = s32[] constant(0)
588+
ROOT dynamic-update-slice = s32[3,3]{1,0} dynamic-update-slice(p0, convert, zero, zero)
589+
}
590+
591+
fused_dynamic_update_slice.1 {
592+
p0 = s32[3,3]{1,0} parameter(0)
593+
p1 = pred[2,3]{1,0} parameter(1)
594+
convert = s32[2,3]{1,0} convert(p1)
595+
zero = s32[] constant(0)
596+
ROOT dynamic-update-slice = s32[3,3]{1,0} dynamic-update-slice(p0, convert, zero, zero)
597+
}
598+
599+
ENTRY main {
600+
param_0 = s32[3,3]{1,0} parameter(0)
601+
param_1 = pred[2,3]{1,0} parameter(1)
602+
param_2 = pred[3,2]{1,0} parameter(2)
603+
loop_dynamic_update_slice_fusion = s32[3,3]{1,0} fusion(param_0, param_2), kind=kLoop, calls=fused_dynamic_update_slice
604+
loop_dynamic_update_slice_fusion.1 = s32[3,3]{1,0} fusion(param_0, param_1), kind=kLoop, calls=fused_dynamic_update_slice.1
605+
ROOT tuple.11.0 = (s32[3,3]{1,0}, s32[3,3]{1,0}) tuple(loop_dynamic_update_slice_fusion.1, loop_dynamic_update_slice_fusion)
606+
}
607+
)"));
608+
EXPECT_FALSE(
609+
HorizontalLoopFusion{device_description_}.Run(module.get()).value());
610+
}
611+
579612
TEST_F(HorizontalLoopFusionTest,
580613
AllowSharedParametersWhenNotUsingConcatenation) {
581614
auto module = ParseAndReturnVerifiedModule(R"(

0 commit comments

Comments
 (0)