@@ -511,13 +511,16 @@ void FlashAttnV3GradBaseKernel(
       dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
   dynload::fa3_bwd_params_set_dq_semaphore(params_handle,
                                            dq_semaphore.data<int>());
+  DenseTensor dk_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+  DenseTensor dv_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
   if (num_heads_k != num_heads &&
       dynload::fa3_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dk;
+    set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int32_t>(0));
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dv;
+    set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int32_t>(0));
     dynload::fa3_bwd_params_set_dk_semaphore(params_handle,
                                              dk_semaphore.data<int>());
     dynload::fa3_bwd_params_set_dv_semaphore(params_handle,
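Note on the hunk above: it hoists the dk/dv semaphore allocations out of the deterministic-only branch and explicitly zero-fills them, since phi::Empty only reserves device memory and leaves it uninitialized. A minimal sketch of the allocate-then-zero pattern the patch relies on; the helper name AllocZeroedSemaphore is illustrative and not part of the change:

// Illustrative helper (hypothetical name), mirroring the dk_semaphore /
// dv_semaphore handling above: phi::Empty only reserves device memory, so the
// buffer must be cleared with SetConstant before the kernel can use it as a
// cross-block counter.
template <typename Context>
phi::DenseTensor AllocZeroedSemaphore(const Context &dev_ctx,
                                      int64_t num_blocks,
                                      int64_t batch_size,
                                      int64_t num_heads) {
  phi::DenseTensor semaphore =
      phi::Empty<int32_t>(dev_ctx, {num_blocks, batch_size, num_heads});
  phi::funcs::SetConstant<Context, int32_t> set_zero;
  set_zero(dev_ctx, &semaphore, static_cast<int32_t>(0));
  return semaphore;
}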
@@ -599,11 +602,6 @@ void FlashAttnV3GradKernel(const Context &dev_ctx,
       0,
       common::errors::InvalidArgument(
           "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));
   // umiswing: fake grad tensor for FlashAttnV3GradBaseKernel
   DenseTensor softmax_d;
   DenseTensor softmax_lse_log2;
@@ -737,11 +735,6 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
       0,
       common::errors::InvalidArgument(
           "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));

   PADDLE_ENFORCE_EQ(
       q.dims()[q.dims().size() - 1],
@@ -1437,13 +1430,17 @@ void FlashMaskV2GradBaseKernel(
       dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
   dynload::flashmaskv2_bwd_params_set_dq_semaphore(params_handle,
                                                    dq_semaphore.data<int>());
+  DenseTensor dk_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+  DenseTensor dv_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
   if (num_heads_k != num_heads &&
       dynload::flashmaskv2_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+    // xiangrui: we need to zero them out
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dk;
+    set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int32_t>(0));
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dv;
+    set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int32_t>(0));
     dynload::flashmaskv2_bwd_params_set_dk_semaphore(params_handle,
                                                      dk_semaphore.data<int>());
     dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle,
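Why the zero-fill matters: the dk/dv semaphores are per-tile counters that the deterministic backward uses to serialize accumulation across thread blocks. The sketch below shows that coordination pattern in a generic form; it is not the fa3/flashmask kernel itself, and every name in it is illustrative. If the counter started at whatever stale value phi::Empty happened to return, blocks would wait on ranks that never arrive or accumulate out of order.

// Generic illustration only (not the actual fa3/flashmask backward kernel):
// several blocks accumulate partial dK results into the same tile in a fixed
// order, coordinated by a per-tile counter ("semaphore") that must start at 0.
__global__ void AccumulateInOrder(float *dk_tile, const float *partials,
                                  int *semaphore, int tile_idx) {
  if (threadIdx.x == 0) {
    // Spin until every lower-indexed block has finished with this tile.
    while (atomicAdd(&semaphore[tile_idx], 0) != static_cast<int>(blockIdx.x)) {
    }
  }
  __syncthreads();
  // A fixed accumulation order makes the floating-point sum reproducible.
  dk_tile[threadIdx.x] += partials[blockIdx.x * blockDim.x + threadIdx.x];
  __threadfence();
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(&semaphore[tile_idx], 1);  // hand the tile to the next block
  }
}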
@@ -1573,39 +1570,40 @@ void FlashMaskV2GradKernel(
   DenseTensor dq_accum;
   DenseTensor dk_accum;
   DenseTensor dv_accum;
-  FlashMaskV2GradBaseKernel<T, Context>(dev_ctx,
-                                        out_grad,
-                                        q,
-                                        k,
-                                        v,
-                                        out,
-                                        softmax_lse,
-                                        paddle::none,  // dq_
-                                        paddle::none,  // dk_
-                                        paddle::none,  // dv_
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        startend_row_indices,
-                                        block_mask,
-                                        0,  // max_seqlen_q,
-                                        0,  // max_seqlen_k,
-                                        softmax_scale,
-                                        is_causal,
-                                        -1,  // window_size_left,
-                                        -1,  // window_size_right,
-                                        0,  // softcap,
-                                        false,  // deterministic,
-                                        0,  // sm_margin,
-                                        dq,
-                                        dk,
-                                        dv,
-                                        &softmax_d,
-                                        &softmax_lse_log2,
-                                        &dq_accum,
-                                        &dk_accum,
-                                        &dv_accum);
+  FlashMaskV2GradBaseKernel<T, Context>(
+      dev_ctx,
+      out_grad,
+      q,
+      k,
+      v,
+      out,
+      softmax_lse,
+      paddle::none,  // dq_
+      paddle::none,  // dk_
+      paddle::none,  // dv_
+      paddle::none,
+      paddle::none,
+      paddle::none,
+      paddle::none,
+      startend_row_indices,
+      block_mask,
+      0,  // max_seqlen_q,
+      0,  // max_seqlen_k,
+      softmax_scale,
+      is_causal,
+      -1,  // window_size_left,
+      -1,  // window_size_right,
+      0,  // softcap,
+      FLAGS_cudnn_deterministic,  // deterministic,
+      0,  // sm_margin,
+      dq,
+      dk,
+      dv,
+      &softmax_d,
+      &softmax_lse_log2,
+      &dq_accum,
+      &dk_accum,
+      &dv_accum);

   // umiswing: some branch in upstream fa3 could have padded the head dimension
   PADDLE_ENFORCE_EQ(
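Net effect of the last three hunks: instead of rejecting FLAGS_cudnn_deterministic with a PADDLE_ENFORCE_EQ, the grad kernels now forward the flag's value as the deterministic argument of the base kernel. A condensed, self-contained sketch of the before/after behaviour; the stand-in flag and function below are illustrative, not Paddle's real definitions:

#include <iostream>

// Stand-in for Paddle's global gflag; illustrative only.
static bool FLAGS_cudnn_deterministic = true;

// Stand-in for the base kernel's deterministic switch.
void FlashMaskV2GradBase(bool deterministic) {
  std::cout << (deterministic ? "deterministic backward path"
                              : "default backward path")
            << "\n";
}

int main() {
  // Before this change: the grad kernels enforced
  //   PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic, false, ...);
  // and always called the base kernel with deterministic = false.
  // After: the user's flag value is forwarded directly.
  FlashMaskV2GradBase(/*deterministic=*/FLAGS_cudnn_deterministic);
  return 0;
}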