
Commit 57848d5

Authored by Konstantinos Pitas (konstantinos-p) and a co-author
Fix geglu (#986)
## Fixes #959 for fp16

The fp16 geglu comparison tests against the original torch implementation were passing only with very loose tolerances. The geglu tests had loose tolerances for both fp32 and fp16, and these appear to be different bugs; this PR fixes the one affecting fp16, which impacted only the gradients of the up_proj matrix.

Specifically, the issue was in the recomputation of the forward pass inside the backward. In the original torch implementation, and for fp16, the forward values *are implicitly cast to fp16* before being stored and reused. That implicit casting step was missing from the current implementation. Note that after downcasting to fp16, I upcast back to fp32 for the computations inside the backward.

The fp16 tests now pass with a tolerance of 1e-2, which is a commonly accepted standard. I did not benchmark performance after the fix, but I expect any impact to be minimal.

## Testing Done

test_geglu.py tests pass with the tighter fp16 tolerance on a 1x RTX 5070.

@Tcc0403

---------

Co-authored-by: Konstantinos Pitas <[email protected]>
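To make the rounding effect concrete, here is a small standalone sketch (illustrative only, not the library code; shapes and names are made up): when an fp16 reference stores gelu(a) rounded to fp16 and reuses it in the backward, db = dc * gelu(a) differs slightly from what an all-fp32 recomputation produces, which is the mismatch the kernel change below removes.

```python
import torch

# Sketch: compare db computed from the "exact" fp32 activation vs. the fp16-stored
# activation that an fp16 torch reference would reuse in its backward.
a = torch.randn(4, 8, dtype=torch.float16)
dc = torch.randn(4, 8, dtype=torch.float16)

gelu_fp32 = torch.nn.functional.gelu(a.float(), approximate="tanh")  # exact fp32 activation
gelu_roundtrip = gelu_fp32.to(torch.float16).float()                 # downcast to fp16, upcast for math

db_exact = dc.float() * gelu_fp32          # what an all-fp32 recomputation yields
db_matched = dc.float() * gelu_roundtrip   # what matches the fp16 reference

print((db_exact - db_matched).abs().max())  # small but nonzero rounding difference
```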
1 parent 8b4c62f commit 57848d5

2 files changed: +7 -6 lines

src/liger_kernel/ops/geglu.py

Lines changed: 3 additions & 2 deletions
@@ -67,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
+    geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
 
-    db_row = dc_row * geglu_a
+    db_row = dc_row.cast(tl.float32) * geglu_a
 
     # Gradient w.r.t. a can be computed with:
     # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -79,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
-    tl.store(b + col_offsets, db_row, mask=mask)
+    tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
 
 
 def geglu_forward(a, b):
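As a side note on the idiom in the diff above, the following is a minimal, hypothetical Triton sketch (kernel and variable names invented for illustration; it needs a GPU with triton installed) of the same pattern: round one operand to the output dtype, upcast to fp32 for the multiply, and store in the output dtype.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _roundtrip_mul_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)  # recomputed value, like geglu_a
    y = tl.load(y_ptr + offsets, mask=mask)                 # incoming grad, like dc_row (fp16)
    # Round x to y's dtype first (mimicking the fp16-stored activation), then compute in fp32.
    x_rounded = x.to(y.dtype).to(tl.float32)
    out = y.to(tl.float32) * x_rounded
    tl.store(out_ptr + offsets, out.to(y.dtype), mask=mask)


if torch.cuda.is_available():
    n = 1024
    x = torch.randn(n, device="cuda", dtype=torch.float32)
    y = torch.randn(n, device="cuda", dtype=torch.float16)
    out = torch.empty_like(y)
    _roundtrip_mul_kernel[(triton.cdiv(n, 256),)](x, y, out, n, BLOCK_SIZE=256)
```

Keeping the arithmetic in fp32 preserves the accuracy of the multiply itself, while the initial rounding matches what the fp16 reference stored.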

test/transformers/test_geglu.py

Lines changed: 4 additions & 4 deletions
@@ -36,8 +36,8 @@
         (torch.float32, 1e-0, 2e-6),
         pytest.param(
             torch.bfloat16,
-            1e4,
-            6e-3,
+            1e-2,
+            1e-2,
             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
         ),
     ],
@@ -117,8 +117,8 @@ def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol,
     [
         # atol is for small values: they have more difference, so set atol higher
         # rtol is for larger values: they are very close, so set rtol lower
-        (torch.float32, 1e-0, 2e-6),
-        (torch.bfloat16, 1e4, 6e-3),
+        (torch.float32, 1e-5, 1e-5),
+        (torch.bfloat16, 1e-2, 1e-2),
     ],
 )
 def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
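For context on how such tolerances are applied, here is a hedged sketch of the kind of comparison the tests perform via torch.testing.assert_close (the helper name below is made up; the real tests parametrize atol/rtol as shown above):

```python
import torch

def assert_outputs_match(liger_out: torch.Tensor, torch_out: torch.Tensor) -> None:
    # fp32 results are compared tightly; reduced-precision results get the looser 1e-2 bound.
    if liger_out.dtype == torch.float32:
        atol, rtol = 1e-5, 1e-5
    else:
        atol, rtol = 1e-2, 1e-2
    torch.testing.assert_close(liger_out, torch_out, atol=atol, rtol=rtol)
```

With the old bfloat16 atol of 1e4, virtually any output would have passed; 1e-2 actually exercises the kernel.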

0 commit comments
