 @pytest.mark.parametrize("reduction", ["none", "mean"])
 def test_unsloth_cross_entropy(reduction):
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import unsloth_ex

@@ -33,22 +32,19 @@ def foo(logits, labels):
     expected = foo(logits, labels)
     torch.testing.assert_close(actual, expected)

-    cfoo_grad = grad(cfoo)
-    actual = cfoo_grad(logits, labels)[0]
-    trace_str = str(thunder.last_traces(cfoo_grad)[-1])
+    (actual_grad,) = torch.autograd.grad(actual.sum(), logits)
+    trace_str = str(thunder.last_backward_traces(cfoo)[-1])
     assert "unsloth_cross_entropy_backward" in trace_str
     out = foo(logits, labels)
     assert logits.grad is None
-    out.sum().backward()
-    expected = logits.grad
-    torch.testing.assert_close(actual, expected)
+    (expected_grad,) = torch.autograd.grad(out.sum(), logits)
+    torch.testing.assert_close(actual_grad, expected_grad)


 @pytest.mark.skip(reason="out of date")
 @_RunIf(min_cuda_gpus=1, thunder=True)
 def test_unsloth_rope():
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import unsloth_ex

@@ -71,21 +67,14 @@ def foo(x, cos, sin):
     expected = foo(q, cos, sin)
     torch.testing.assert_close(actual, expected)

-    cfoo_grad = grad(cfoo)
-    actual = cfoo_grad(q, cos, sin)[0]
-    trace_str = str(thunder.last_traces(cfoo_grad)[-1])
-    assert "unsloth_apply_rope_backward" in trace_str
-    out = foo(q, cos, sin)
-    assert q.grad is None
-    out.sum().backward()
-    expected = q.grad
-    torch.testing.assert_close(actual, expected)
+    (actual_grad,) = torch.autograd.grad(actual.sum(), q)
+    (expected_grad,) = torch.autograd.grad(expected.sum(), q)
+    torch.testing.assert_close(actual_grad, expected_grad)


 @_RunIf(min_cuda_gpus=1, thunder=True)
 def test_unsloth_swiglu():
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import ThunderLLaMAMLP, unsloth_ex
     from litgpt import Config

@@ -108,21 +97,14 @@ def test_unsloth_swiglu():
     expected = mlp(x)
     torch.testing.assert_close(actual, expected)

-    cmlp_grad = grad(cmlp)
-    actual = cmlp_grad(x)[0]
-    trace_str = str(thunder.last_traces(cmlp_grad)[-1])
-    assert "unsloth_swiglu_backward" in trace_str
-    out = mlp(x)
-    assert x.grad is None
-    out.sum().backward()
-    expected = x.grad
-    torch.testing.assert_close(actual, expected)
+    (actual_grad,) = torch.autograd.grad(actual.sum(), x)
+    (expected_grad,) = torch.autograd.grad(expected.sum(), x)
+    torch.testing.assert_close(actual_grad, expected_grad)


 @_RunIf(min_cuda_gpus=1, thunder=True)
 def test_unsloth_gpt():
     import thunder
-    from thunder.core.transforms import grad

     from extensions.thunder.unsloth.executor import unsloth_ex

@@ -164,11 +146,3 @@ def forward_and_loss(model, input_ids, targets):
     assert "unsloth_apply_rope_backward" in bwd_str
     assert "unsloth_swiglu" in fwd_str
     assert "unsloth_swiglu_backward" in bwd_str
-
-    cfn_grad = grad(cfn)
-    _ = cfn_grad(model, input_ids, targets)
-    bwd = thunder.last_traces(cfn_grad)
-    bwd_str = bwd[-1].python()
-    assert "unsloth_cross_entropy_backward" in bwd_str
-    assert "unsloth_apply_rope_backward" in bwd_str
-    assert "unsloth_swiglu_backward" in bwd_str
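For reference, the gradient-checking pattern this diff converges on can be sketched as a standalone snippet. This is an illustrative sketch only, not code from the repository: `thunder.jit` as the compile entry point, the toy `foo`, and the tensor shapes are assumptions, while the `torch.autograd.grad` and `thunder.last_backward_traces` usage mirrors the diff above.

import torch
import thunder

def foo(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels)

# Assumption: thunder.jit is the compile entry point; the diff does not show
# how `cfoo` is created.
cfoo = thunder.jit(foo)

logits = torch.randn(4, 8, requires_grad=True)
labels = torch.randint(0, 8, (4,))

# Forward outputs of the compiled and eager functions should match.
actual = cfoo(logits, labels)
expected = foo(logits, labels)
torch.testing.assert_close(actual, expected)

# Differentiate both outputs with plain torch.autograd.grad (replacing the
# removed thunder.core.transforms.grad transform) and compare the gradients.
(actual_grad,) = torch.autograd.grad(actual.sum(), logits)
(expected_grad,) = torch.autograd.grad(expected.sum(), logits)
torch.testing.assert_close(actual_grad, expected_grad)

# The backward trace of the compiled callable can be inspected directly to
# confirm the custom backward kernel was used.
trace_str = str(thunder.last_backward_traces(cfoo)[-1])
assert "unsloth_cross_entropy_backward" in trace_str

This removes the need to build a separate grad-transformed callable (`cfoo_grad`) and mutate `.grad` via `backward()`; the compiled function's own autograd path is exercised instead.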