
Commit c5fbe36 (1 parent: 0b0c819)

Use torch autograd integration for backward testing (#2123)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
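In short: the tests previously obtained gradients by wrapping the compiled function in thunder's grad transform; after this commit they differentiate the compiled output directly through PyTorch's autograd integration and read the backward trace with thunder.last_backward_traces. A minimal sketch of the new pattern, assuming thunder.jit as the compilation entry point; fn, cfn, and x are illustrative stand-ins, not names from the test file:

import torch
import thunder

def fn(x):
    # Stand-in for the unsloth-accelerated functions under test.
    return (x * x).tanh()

cfn = thunder.jit(fn)  # assumption: thunder.jit is the compilation entry point
x = torch.randn(4, requires_grad=True)

out = cfn(x)
# New pattern: gradients flow through PyTorch's autograd integration...
(actual_grad,) = torch.autograd.grad(out.sum(), x)
# ...and the backward trace is fetched from the compiled function itself.
trace_str = str(thunder.last_backward_traces(cfn)[-1])

# Reference gradient from the eager function for comparison.
(expected_grad,) = torch.autograd.grad(fn(x).sum(), x)
torch.testing.assert_close(actual_grad, expected_grad)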


tests/ext_thunder/test_unsloth_executor.py

Lines changed: 10 additions & 36 deletions
@@ -10,7 +10,6 @@
 @pytest.mark.parametrize("reduction", ["none", "mean"])
 def test_unsloth_cross_entropy(reduction):
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import unsloth_ex

@@ -33,22 +32,19 @@ def foo(logits, labels):
     expected = foo(logits, labels)
     torch.testing.assert_close(actual, expected)

-    cfoo_grad = grad(cfoo)
-    actual = cfoo_grad(logits, labels)[0]
-    trace_str = str(thunder.last_traces(cfoo_grad)[-1])
+    (actual_grad,) = torch.autograd.grad(actual.sum(), logits)
+    trace_str = str(thunder.last_backward_traces(cfoo)[-1])
     assert "unsloth_cross_entropy_backward" in trace_str
     out = foo(logits, labels)
     assert logits.grad is None
-    out.sum().backward()
-    expected = logits.grad
-    torch.testing.assert_close(actual, expected)
+    (expected_grad,) = torch.autograd.grad(out.sum(), logits)
+    torch.testing.assert_close(actual_grad, expected_grad)


 @pytest.mark.skip(reason="out of date")
 @_RunIf(min_cuda_gpus=1, thunder=True)
 def test_unsloth_rope():
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import unsloth_ex
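One detail worth noting: torch.autograd.grad returns the gradient directly and, unlike .backward(), never populates the input's .grad attribute, which is why the assert logits.grad is None check above still holds after both gradient computations. A self-contained illustration (not from the test file):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# The gradient is handed back as a tuple...
(g,) = torch.autograd.grad(y, x)
# ...while x.grad stays untouched, unlike after y.backward().
assert x.grad is None
torch.testing.assert_close(g, torch.full((3,), 2.0))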

@@ -71,21 +67,14 @@ def foo(x, cos, sin):
     expected = foo(q, cos, sin)
     torch.testing.assert_close(actual, expected)

-    cfoo_grad = grad(cfoo)
-    actual = cfoo_grad(q, cos, sin)[0]
-    trace_str = str(thunder.last_traces(cfoo_grad)[-1])
-    assert "unsloth_apply_rope_backward" in trace_str
-    out = foo(q, cos, sin)
-    assert q.grad is None
-    out.sum().backward()
-    expected = q.grad
-    torch.testing.assert_close(actual, expected)
+    (actual_grad,) = torch.autograd.grad(actual.sum(), q)
+    (expected_grad,) = torch.autograd.grad(expected.sum(), q)
+    torch.testing.assert_close(actual_grad, expected_grad)


 @_RunIf(min_cuda_gpus=1, thunder=True)
 def test_unsloth_swiglu():
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import ThunderLLaMAMLP, unsloth_ex
     from litgpt import Config
@@ -108,21 +97,14 @@ def test_unsloth_swiglu():
     expected = mlp(x)
     torch.testing.assert_close(actual, expected)

-    cmlp_grad = grad(cmlp)
-    actual = cmlp_grad(x)[0]
-    trace_str = str(thunder.last_traces(cmlp_grad)[-1])
-    assert "unsloth_swiglu_backward" in trace_str
-    out = mlp(x)
-    assert x.grad is None
-    out.sum().backward()
-    expected = x.grad
-    torch.testing.assert_close(actual, expected)
+    (actual_grad,) = torch.autograd.grad(actual.sum(), x)
+    (expected_grad,) = torch.autograd.grad(expected.sum(), x)
+    torch.testing.assert_close(actual_grad, expected_grad)


 @_RunIf(min_cuda_gpus=1, thunder=True)
 def test_unsloth_gpt():
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import unsloth_ex
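The rope and swiglu tests use the same symmetric pattern: differentiate both the compiled and the eager outputs with torch.autograd.grad against the shared input. A minimal sketch with a toy module; the Sequential MLP and the thunder.jit usage are illustrative assumptions, whereas the real test uses ThunderLLaMAMLP with unsloth_ex:

import torch
import thunder

mlp = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.SiLU(), torch.nn.Linear(16, 8))
cmlp = thunder.jit(mlp)  # assumption: thunder.jit also wraps nn.Modules

x = torch.randn(2, 8, requires_grad=True)
actual = cmlp(x)
expected = mlp(x)
torch.testing.assert_close(actual, expected)

# The same leaf tensor can back both graphs, since torch.autograd.grad
# does not accumulate into x.grad.
(actual_grad,) = torch.autograd.grad(actual.sum(), x)
(expected_grad,) = torch.autograd.grad(expected.sum(), x)
torch.testing.assert_close(actual_grad, expected_grad)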

@@ -164,11 +146,3 @@ def forward_and_loss(model, input_ids, targets):
     assert "unsloth_apply_rope_backward" in bwd_str
     assert "unsloth_swiglu" in fwd_str
     assert "unsloth_swiglu_backward" in bwd_str
-
-    cfn_grad = grad(cfn)
-    _ = cfn_grad(model, input_ids, targets)
-    bwd = thunder.last_traces(cfn_grad)
-    bwd_str = bwd[-1].python()
-    assert "unsloth_cross_entropy_backward" in bwd_str
-    assert "unsloth_apply_rope_backward" in bwd_str
-    assert "unsloth_swiglu_backward" in bwd_str
