@@ -1323,7 +1323,7 @@ index e9cb21b9fa..1ba8c60b50 100644
1323
1323
MakeGetTupleElementHlo(new_conv, 0));
1324
1324
TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
1325
1325
diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
1326
- index f03fe4f0fa..646883b3e9 100644
1326
+ index f03fe4f0fa..468fa5c6dd 100644
1327
1327
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
1328
1328
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
1329
1329
@@ -234,12 +234,14 @@ auto GetUnfusedReduceMaxSumSoftmaxPattern(
@@ -1382,7 +1382,23 @@ index f03fe4f0fa..646883b3e9 100644
1382
1382
return is_flash_attention;
1383
1383
}
1384
1384
1385
- @@ -676,6 +684,12 @@ MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
1385
+ @@ -621,6 +629,7 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
1386
+ has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget
1387
+ : kCudnnfMHAScaleBiasSoftmaxCallTarget;
1388
+ match_result.is_causal_mask |= IsCausalMaskPattern(bias);
1389
+ + #if !TENSORFLOW_USE_SYCL
1390
+ if (!match_result.is_causal_mask &&
1391
+ bias->opcode() == HloOpcode::kBroadcast) {
1392
+ // we can take the bias before broadcast
1393
+ @@ -640,6 +649,7 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
1394
+ bias_bc));
1395
+ }
1396
+ }
1397
+ + #endif
1398
+ match_result.matched_bias = bias;
1399
+ match_result.has_match = true;
1400
+ } else {
1401
+ @@ -676,6 +686,12 @@ MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
1386
1402
continue;
1387
1403
}
1388
1404
has_dropout = match_result.matched_dropout_rate > 0.0;
@@ -1395,15 +1411,15 @@ index f03fe4f0fa..646883b3e9 100644
1395
1411
match_result = MatchBmm1UnfusedBiasSoftmaxBmm2(
1396
1412
match_result, match_result.matched_softmax_input, has_dropout);
1397
1413
if (match_result.has_match) {
1398
- @@ -1087,6 +1101 ,7 @@ absl::StatusOr<bool> IsMHABlockSupported(
1414
+ @@ -1087,6 +1103 ,7 @@ absl::StatusOr<bool> IsMHABlockSupported(
1399
1415
TF_ASSIGN_OR_RETURN(
1400
1416
bool is_flash_attention,
1401
1417
IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version));
1402
1418
+ #if !TENSORFLOW_USE_SYCL
1403
1419
if (is_flash_attention) {
1404
1420
if (is_causal_mask) {
1405
1421
// if bias is causal mask, needs to remove bias from name
1406
- @@ -1098,6 +1113 ,11 @@ absl::StatusOr<bool> IsMHABlockSupported(
1422
+ @@ -1098,6 +1115 ,11 @@ absl::StatusOr<bool> IsMHABlockSupported(
1407
1423
}
1408
1424
}
1409
1425
return is_flash_attention;
@@ -1415,31 +1431,31 @@ index f03fe4f0fa..646883b3e9 100644
1415
1431
}
1416
1432
1417
1433
absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
1418
- @@ -1627,6 +1647 ,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1434
+ @@ -1627,6 +1649 ,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1419
1435
comp->parent()->config().debug_options();
1420
1436
const se::dnn::VersionInfo cudnn_version =
1421
1437
GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
1422
1438
+ #if !TENSORFLOW_USE_SYCL
1423
1439
#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
1424
1440
// CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
1425
1441
// Some cuDNN versions work with CUDA 11, but it is impractical for us to
1426
- @@ -1639,6 +1660 ,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1442
+ @@ -1639,6 +1662 ,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1427
1443
stream_executor::dnn::VersionInfo(8, 9, 4))) {
1428
1444
return false;
1429
1445
}
1430
1446
+ #endif // !TENSORFLOW_USE_SYCL
1431
1447
for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
1432
1448
bool v_transposed = false;
1433
1449
bool changed = false;
1434
- @@ -1721,6 +1743 ,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1450
+ @@ -1721,6 +1745 ,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1435
1451
matched_result.need_canonicalization));
1436
1452
continue;
1437
1453
}
1438
1454
+ #if !TENSORFLOW_USE_SYCL
1439
1455
if (matched_bwd_result.matched_dbias &&
1440
1456
!(compute_capability_.IsAtLeastHopper() &&
1441
1457
compute_capability_.minor == 0 &&
1442
- @@ -1734,6 +1757 ,17 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1458
+ @@ -1734,6 +1759 ,17 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1443
1459
matched_result.need_canonicalization));
1444
1460
continue;
1445
1461
}
0 commit comments