diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 463498e8ca..66a61a387b 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -238,12 +238,21 @@ class CollectiveEpilogue< } } + bool is_beta_zero = (args.thread.beta == decltype(args.thread.beta)(0)); + bool beta_implementable = true; + + if (!is_source_supported || args.ptr_C == nullptr) { + beta_implementable = is_beta_zero; + } + if constexpr (is_source_supported) { - constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; - implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dC); - if (L > 1) { - constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; - implementable &= get<2>(args.dC) % min_batch_aligned_elements_C == 0; + if (!is_beta_zero) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dC); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dC) % min_batch_aligned_elements_C == 0; + } } } @@ -257,7 +266,11 @@ class CollectiveEpilogue< CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); } - return implementable && fusion_implementable; + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && beta_implementable && fusion_implementable; } CUTLASS_HOST_DEVICE