Commit b708f79
[NPU]: update the native KLDivLoss implementation for comparison (e.g. test_jsd.py) (#1032)
## Summary

This PR modifies the NPU test reference for KLDivLoss. The native NPU KLDivLoss operator does not support gradients w.r.t. the target ([#1021](#1021)), which caused failures in test_jsd.py (where input and target are swapped when beta != 0). To resolve this, I replaced the native operator with a custom implementation built from basic math operations. This allows correct gradient computation for the target and aligns the x1.grad results with the Triton kernel implementation.

## Testing Done

I ran test_jsd and test_fused_linear_jsd with the commands below, and all cases passed:

```shell
pytest -v test/transformers/test_jsd.py
pytest -v test/transformers/test_fused_linear_jsd.py
```

Hardware Type: Ascend NPU 910B3

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
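The "basic math operations" replacement relies on the identity PyTorch documents for `log_target=True`: the elementwise loss is `exp(target) * (target - input)`. A minimal CPU sketch (names here are illustrative, not from the patch) confirming the manual formula matches `torch.nn.KLDivLoss`:

```python
import torch

torch.manual_seed(0)

# Both arguments are log-probabilities, matching log_target=True usage.
inp = torch.randn(4, 5).log_softmax(dim=-1)
tgt = torch.randn(4, 5).log_softmax(dim=-1)

# Manual elementwise KL divergence, as used by the custom NPU reference.
manual = torch.exp(tgt) * (tgt - inp)

# Native operator with the same settings as the test reference.
ref = torch.nn.KLDivLoss(reduction="none", log_target=True)(inp, tgt)

print(torch.allclose(manual, ref, atol=1e-6))  # True
```

Because the manual version is plain autograd-traceable math, gradients propagate through both arguments, which is exactly what the native NPU operator fails to do for the target.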
1 parent 559e9a1 commit b708f79

File tree

1 file changed (+29 −1 lines)


test/transformers/test_jsd.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -18,6 +18,31 @@
 set_seed(42)


+class NPUKLDivLoss(torch.nn.Module):
+    """
+    A custom KLDivLoss for NPU.
+
+    On NPU devices, torch.nn.KLDivLoss does not compute gradients with respect to the target.
+    This leads to incorrect gradient computation when the target depends on the input,
+    such as in JSD or reverse KLDiv.
+    See https://github.com/linkedin/Liger-Kernel/issues/1021 for more details.
+    """
+
+    def __init__(self, reduction="none", log_target=True):
+        super().__init__()
+
+    def forward(self, input, target):
+        original_dtype = input.dtype
+
+        if input.dtype in [torch.float16, torch.bfloat16]:
+            input = input.float()
+            target = target.float()
+
+        loss = torch.exp(target) * (target - input)
+
+        return loss.to(original_dtype)
+
+
 class JSD(torch.nn.Module):
     def __init__(
         self,
@@ -26,7 +51,10 @@ def __init__(
         dtype: torch.dtype = torch.float,
     ):
         super(JSD, self).__init__()
-        self.kl = KLDivLoss(reduction="none", log_target=True)
+        if device == "npu":
+            self.kl = NPUKLDivLoss(reduction="none", log_target=True)
+        else:
+            self.kl = KLDivLoss(reduction="none", log_target=True)
         self.beta = beta
         self.ignore_index = ignore_index
         self.dtype = dtype
```
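As a sanity check of the patch's intent, the sketch below restates `NPUKLDivLoss` from the diff on CPU and confirms that gradients flow back to the tensor the target was built from — the situation test_jsd.py exercises when input and target are swapped for beta != 0 (the surrounding setup is illustrative, not taken from the test file):

```python
import torch

class NPUKLDivLoss(torch.nn.Module):
    """CPU restatement of the patched class: KLDiv via basic math ops."""
    def __init__(self, reduction="none", log_target=True):
        super().__init__()

    def forward(self, input, target):
        original_dtype = input.dtype
        if input.dtype in [torch.float16, torch.bfloat16]:
            # Upcast for numerical stability, then cast the loss back.
            input = input.float()
            target = target.float()
        loss = torch.exp(target) * (target - input)
        return loss.to(original_dtype)

# The target depends on a leaf tensor, as in JSD when beta != 0.
x = torch.randn(2, 3, requires_grad=True)
target = x.log_softmax(dim=-1)
inp = torch.randn(2, 3).log_softmax(dim=-1)

loss = NPUKLDivLoss()(inp, target)
loss.sum().backward()
print(x.grad is not None)  # True: the target received a gradient
```

With the native NPU operator, `x.grad` would be missing or wrong in this scenario, which is precisely why the test reference switches to the manual formula on `device == "npu"`.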
