Skip to content

Commit 319364c

Browse files
author
Lu Teng
committed
Fix forward MHA accuracy error.
1 parent: 8013f6a · commit: 319364c

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

third_party/openxla.patch

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1323,7 +1323,7 @@ index e9cb21b9fa..1ba8c60b50 100644
13231323
MakeGetTupleElementHlo(new_conv, 0));
13241324
TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
13251325
diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
1326-
index f03fe4f0fa..646883b3e9 100644
1326+
index f03fe4f0fa..468fa5c6dd 100644
13271327
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
13281328
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
13291329
@@ -234,12 +234,14 @@ auto GetUnfusedReduceMaxSumSoftmaxPattern(
@@ -1382,7 +1382,23 @@ index f03fe4f0fa..646883b3e9 100644
13821382
return is_flash_attention;
13831383
}
13841384

1385-
@@ -676,6 +684,12 @@ MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
1385+
@@ -621,6 +629,7 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
1386+
has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget
1387+
: kCudnnfMHAScaleBiasSoftmaxCallTarget;
1388+
match_result.is_causal_mask |= IsCausalMaskPattern(bias);
1389+
+#if !TENSORFLOW_USE_SYCL
1390+
if (!match_result.is_causal_mask &&
1391+
bias->opcode() == HloOpcode::kBroadcast) {
1392+
// we can take the bias before broadcast
1393+
@@ -640,6 +649,7 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
1394+
bias_bc));
1395+
}
1396+
}
1397+
+#endif
1398+
match_result.matched_bias = bias;
1399+
match_result.has_match = true;
1400+
} else {
1401+
@@ -676,6 +686,12 @@ MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
13861402
continue;
13871403
}
13881404
has_dropout = match_result.matched_dropout_rate > 0.0;
@@ -1395,15 +1411,15 @@ index f03fe4f0fa..646883b3e9 100644
13951411
match_result = MatchBmm1UnfusedBiasSoftmaxBmm2(
13961412
match_result, match_result.matched_softmax_input, has_dropout);
13971413
if (match_result.has_match) {
1398-
@@ -1087,6 +1101,7 @@ absl::StatusOr<bool> IsMHABlockSupported(
1414+
@@ -1087,6 +1103,7 @@ absl::StatusOr<bool> IsMHABlockSupported(
13991415
TF_ASSIGN_OR_RETURN(
14001416
bool is_flash_attention,
14011417
IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version));
14021418
+#if !TENSORFLOW_USE_SYCL
14031419
if (is_flash_attention) {
14041420
if (is_causal_mask) {
14051421
// if bias is causal mask, needs to remove bias from name
1406-
@@ -1098,6 +1113,11 @@ absl::StatusOr<bool> IsMHABlockSupported(
1422+
@@ -1098,6 +1115,11 @@ absl::StatusOr<bool> IsMHABlockSupported(
14071423
}
14081424
}
14091425
return is_flash_attention;
@@ -1415,31 +1431,31 @@ index f03fe4f0fa..646883b3e9 100644
14151431
}
14161432

14171433
absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
1418-
@@ -1627,6 +1647,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1434+
@@ -1627,6 +1649,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
14191435
comp->parent()->config().debug_options();
14201436
const se::dnn::VersionInfo cudnn_version =
14211437
GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
14221438
+#if !TENSORFLOW_USE_SYCL
14231439
#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
14241440
// CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
14251441
// Some cuDNN versions work with CUDA 11, but it is impractical for us to
1426-
@@ -1639,6 +1660,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1442+
@@ -1639,6 +1662,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
14271443
stream_executor::dnn::VersionInfo(8, 9, 4))) {
14281444
return false;
14291445
}
14301446
+#endif // !TENSORFLOW_USE_SYCL
14311447
for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
14321448
bool v_transposed = false;
14331449
bool changed = false;
1434-
@@ -1721,6 +1743,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1450+
@@ -1721,6 +1745,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
14351451
matched_result.need_canonicalization));
14361452
continue;
14371453
}
14381454
+#if !TENSORFLOW_USE_SYCL
14391455
if (matched_bwd_result.matched_dbias &&
14401456
!(compute_capability_.IsAtLeastHopper() &&
14411457
compute_capability_.minor == 0 &&
1442-
@@ -1734,6 +1757,17 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
1458+
@@ -1734,6 +1759,17 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
14431459
matched_result.need_canonicalization));
14441460
continue;
14451461
}

0 commit comments

Comments
 (0)