@@ -1281,15 +1281,32 @@ index 4bdd18ae63..c845a910c2 100644
1281
1281
if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
1282
1282
return user_index.size() == 1 && user_index[0] == 0;
1283
1283
diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/xla/service/gpu/cudnn_fused_conv_rewriter.cc
1284
- index a1f9bc04a0..1e4651d7ae 100644
1284
+ index a1f9bc04a0..82c915811b 100644
1285
1285
--- a/xla/service/gpu/cudnn_fused_conv_rewriter.cc
1286
1286
+++ b/xla/service/gpu/cudnn_fused_conv_rewriter.cc
1287
- @@ -858,6 +858,13 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1287
+ @@ -807,6 +807,8 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1288
+ bool can_accept_bias =
1289
+ Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
1290
+ bool can_accept_side_input = conv->operand_count() < 4;
1291
+ + // Flag to tell whether the `side_input` is really accepted.
1292
+ + bool accepted_side_input = false;
1293
+
1294
+ // The addend can be fused as a bias if
1295
+ // - it is 1D broadcasted in the output feature dimension, and
1296
+ @@ -843,6 +845,7 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1297
+ CHECK_EQ(new_operands.size(), 3);
1298
+ new_operands.push_back(addend);
1299
+ config.set_side_input_scale(1);
1300
+ + accepted_side_input = true;
1301
+ } else {
1302
+ // Can't fuse; this op already has a bias and a side-input.
1303
+ continue;
1304
+ @@ -858,6 +861,13 @@ absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
1288
1305
conv->CloneWithNewOperands(conv->shape(), new_operands));
1289
1306
comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
1290
1307
TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
1291
1308
+ #if TENSORFLOW_USE_SYCL
1292
- + if (can_accept_side_input ) {
1309
+ + if (accepted_side_input ) {
1293
1310
+ xla::Cast<HloCustomCallInstruction>(new_conv)
1294
1311
+ ->set_output_to_operand_aliasing(
1295
1312
+ {{{0}, {static_cast<long>(new_operands.size()) - 1, {}}}});
0 commit comments