Commit 57848d5
Fix geglu (#986)
## Fixes #959 for fp16
The geglu comparison tests against the original torch implementation were only
passing with very loose tolerances. This PR fixes the issue for fp16.
The geglu tests required loose tolerances for both fp32 and fp16; these appear to be
separate bugs. Here I fix the bug affecting fp16, which only impacted the
gradients for the up_proj matrix. Specifically, the issue was in the
recomputation of the forward pass inside the backward. In the original torch
implementation, and for fp16, the forward values *are implicitly cast to
fp16* before being stored and reused. This implicit casting step was missing from
the current implementation. Note that after downcasting to fp16, I
upcast back to fp32 for the computations inside the backward.
The FP16 tests now pass with a tolerance of 1e-2, which is a commonly
accepted standard. I did not benchmark performance after the bug fix,
but I expect any impact to be minimal.
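For illustration, a minimal sketch of the idea in plain PyTorch (the actual fix lives in the kernel code, which this does not reproduce; the names `geglu_backward_sketch`, `grad_out`, `gate`, and `up` are hypothetical). It recomputes `gelu(gate)` in fp32, casts it down to fp16 to mirror what the reference torch path stores, and then upcasts back to fp32 for the gradient math:

```python
import torch

SQRT_2_OVER_PI = 0.7978845608028654  # sqrt(2 / pi)

def gelu_tanh(x):
    # tanh-approximated GELU, computed in fp32
    return 0.5 * x * (1.0 + torch.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x * x * x)))

def geglu_backward_sketch(grad_out, gate, up):
    # grad_out, gate, up: fp16 tensors saved from the forward pass (hypothetical sketch)

    # 1) Recompute the forward activation in fp32 ...
    gelu_gate_f32 = gelu_tanh(gate.float())

    # 2) ... then mimic the reference torch path, which stores the
    #    recomputed activation in fp16 before reusing it.
    gelu_gate_f16 = gelu_gate_f32.to(torch.float16)

    # 3) Upcast back to fp32 for the gradient math itself.
    g = grad_out.float()
    d_up = g * gelu_gate_f16.float()  # gradient flowing to the up branch

    # derivative of tanh-GELU w.r.t. gate, in fp32
    x = gate.float()
    t = torch.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x * x * x))
    dgelu = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * SQRT_2_OVER_PI * (1.0 + 3.0 * 0.044715 * x * x)
    d_gate = g * up.float() * dgelu

    return d_up.to(torch.float16), d_gate.to(torch.float16)
```

Without step 2, the recomputed activation enters `d_up` at fp32 precision while the reference keeps it at fp16, which is enough to push the up_proj gradients outside a 1e-2 tolerance.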
## Testing Done
The test_geglu.py tests pass with the tighter fp16 tolerance on a
1x RTX 5070.
@Tcc0403
---------
Co-authored-by: Konstantinos Pitas <[email protected]>
2 files changed: +7 -6 lines changed