Commit 3ad6339

fix
1 parent 2ecb424 commit 3ad6339

File tree

1 file changed: +9 / -8 lines changed
  • src/infiniop/ops/cross_entropy_loss_backward/cuda

src/infiniop/ops/cross_entropy_loss_backward/cuda/kernel.cuh

Lines changed: 9 additions & 8 deletions
@@ -11,23 +11,24 @@ public:
     __device__ __forceinline__ T operator()(const T &probs, const T &target, const size_t batch_size) const {
         // grad_logits = (probs - target) / batch_size (reduction='mean')
         T diff;
-        T scale = static_cast<T>(1.0) / static_cast<T>(batch_size);
+        T scale;
 
-        if constexpr (std::is_same_v<T, half2>) {
-            diff = __hsub2(probs, target);
-            return __hmul2(diff, __float2half2_rn(static_cast<float>(scale)));
-        } else if constexpr (std::is_same_v<T, half>) {
+        if constexpr (std::is_same_v<T, half>) {
             diff = __hsub(probs, target);
-            return __hmul(diff, __float2half(static_cast<float>(scale)));
+            scale = static_cast<half>(1.0) / __float2half(static_cast<float>(batch_size));
+            return __hmul(diff, scale);
         } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
             diff = __hsub(probs, target);
-            return __hmul(diff, __float2bfloat16(static_cast<float>(scale)));
+            scale = static_cast<cuda_bfloat16>(1.0) / __float2bfloat16(static_cast<float>(batch_size));
+            return __hmul(diff, scale);
         } else if constexpr (std::is_same_v<T, float>) {
             diff = __fsub_rd(probs, target);
-            return __fmul_rd(diff, static_cast<float>(scale));
+            scale = 1.0 / batch_size;
+            return __fmul_rd(diff, scale);
         } else {
             // fallback for other types (double, etc.)
             diff = probs - target;
+            scale = 1.0 / batch_size;
             return diff * scale;
         }
     }
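
The net effect of the fix is that `scale` is now computed inside each `if constexpr` branch using that element type's own conversion and arithmetic (`__float2half`, `__float2bfloat16`, plain float division), instead of once up front via `static_cast<T>`, and the unused `half2` branch is dropped. For context, a functor like this would typically be applied elementwise over the flattened logits; the driver kernel below is only a minimal sketch of such a launch, and its name, signature, and launch shape are assumptions for illustration rather than the repository's actual wrapper code.

#include <cstddef>

// Hypothetical elementwise driver (illustrative, not from this repo):
// applies a binary functor such as the operator() above to every
// (probs, target) pair, writing the scaled gradient to grad_logits.
template <typename T, typename Op>
__global__ void crossEntropyBackwardElementwise(T *grad_logits, const T *probs,
                                                const T *target, size_t n,
                                                size_t batch_size, Op op) {
    size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        // grad_logits[i] = (probs[i] - target[i]) / batch_size, computed in T arithmetic
        grad_logits[i] = op(probs[i], target[i], batch_size);
    }
}

A launch along the lines of `crossEntropyBackwardElementwise<<<(n + 255) / 256, 256>>>(grad, probs, target, n, batch_size, op)` (where `op` is an instance of the functor in kernel.cuh) would then produce grad_logits = (probs - target) / batch_size for every element.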
