Skip to content

Commit 8a3c857

Browse files
author
Lu Teng
authored
[Bug fix] Fix in-place error. (#430)
1 parent 9cd7344 commit 8a3c857

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

third_party/openxla.patch

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,15 +1281,32 @@ index 4bdd18ae63..c845a910c2 100644
12811281
if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
12821282
return user_index.size() == 1 && user_index[0] == 0;
12831283
diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/xla/service/gpu/cudnn_fused_conv_rewriter.cc
1284-
index a1f9bc04a0..1e4651d7ae 100644
1284+
index a1f9bc04a0..82c915811b 100644
12851285
--- a/xla/service/gpu/cudnn_fused_conv_rewriter.cc
12861286
+++ b/xla/service/gpu/cudnn_fused_conv_rewriter.cc
1287-
@@ -858,6 +858,13 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1287+
@@ -807,6 +807,8 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1288+
bool can_accept_bias =
1289+
Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
1290+
bool can_accept_side_input = conv->operand_count() < 4;
1291+
+ // Flag to tell whether the `side_input` is really accepted.
1292+
+ bool accepted_side_input = false;
1293+
1294+
// The addend can be fused as a bias if
1295+
// - it is 1D broadcasted in the output feature dimension, and
1296+
@@ -843,6 +845,7 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1297+
CHECK_EQ(new_operands.size(), 3);
1298+
new_operands.push_back(addend);
1299+
config.set_side_input_scale(1);
1300+
+ accepted_side_input = true;
1301+
} else {
1302+
// Can't fuse; this op already has a bias and a side-input.
1303+
continue;
1304+
@@ -858,6 +861,13 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
12881305
conv->CloneWithNewOperands(conv->shape(), new_operands));
12891306
comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
12901307
TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
12911308
+#if TENSORFLOW_USE_SYCL
1292-
+ if (can_accept_side_input) {
1309+
+ if (accepted_side_input) {
12931310
+ xla::Cast<HloCustomCallInstruction>(new_conv)
12941311
+ ->set_output_to_operand_aliasing(
12951312
+ {{{0}, {static_cast<long>(new_operands.size()) - 1, {}}}});

0 commit comments

Comments
 (0)