@@ -11,23 +11,24 @@ public:
     __device__ __forceinline__ T operator()(const T &probs, const T &target, const size_t batch_size) const {
         // grad_logits = (probs - target) / batch_size (reduction='mean')
         T diff;
-        T scale = static_cast<T>(1.0) / static_cast<T>(batch_size);
+        T scale;
 
-        if constexpr (std::is_same_v<T, half2>) {
-            diff = __hsub2(probs, target);
-            return __hmul2(diff, __float2half2_rn(static_cast<float>(scale)));
-        } else if constexpr (std::is_same_v<T, half>) {
+        if constexpr (std::is_same_v<T, half>) {
             diff = __hsub(probs, target);
-            return __hmul(diff, __float2half(static_cast<float>(scale)));
+            scale = static_cast<half>(1.0) / __float2half(static_cast<float>(batch_size));
+            return __hmul(diff, scale);
         } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
             diff = __hsub(probs, target);
-            return __hmul(diff, __float2bfloat16(static_cast<float>(scale)));
+            scale = static_cast<cuda_bfloat16>(1.0) / __float2bfloat16(static_cast<float>(batch_size));
+            return __hmul(diff, scale);
         } else if constexpr (std::is_same_v<T, float>) {
             diff = __fsub_rd(probs, target);
-            return __fmul_rd(diff, static_cast<float>(scale));
+            scale = 1.0f / static_cast<float>(batch_size);
+            return __fmul_rd(diff, scale);
         } else {
             // fallback for other types (double, etc.)
             diff = probs - target;
+            scale = 1.0 / batch_size;
             return diff * scale;
         }
     }
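For context, this functor is presumably applied elementwise over the flattened [batch_size, num_classes] gradient buffer. Below is a minimal usage sketch, assuming the functor above is named CrossEntropyLossBackwardOp<T> (a hypothetical name, not taken from this diff):

// Usage sketch (assumption): one thread per element of the probs/target
// buffers; CrossEntropyLossBackwardOp is an assumed name for the functor above.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T, typename Op>
__global__ void apply_backward(const T *probs, const T *target, T *grad_logits,
                               size_t n, size_t batch_size, Op op) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // grad_logits = (probs - target) / batch_size, computed per element
        grad_logits[i] = op(probs[i], target[i], batch_size);
    }
}

// Example launch over n = batch_size * num_classes elements:
//   int threads = 256;
//   int blocks  = (int)((n + threads - 1) / threads);
//   apply_backward<<<blocks, threads>>>(probs, target, grad_logits, n,
//                                       batch_size, CrossEntropyLossBackwardOp<float>{});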