Commit fd0d54b

[Modules] Enhance Testing of l2warp (#448)
1 parent 6ebb28a commit fd0d54b

File tree: 2 files changed (+21, -15 lines)


fla/modules/l2warp.py

Lines changed: 18 additions & 12 deletions
@@ -11,22 +11,28 @@ class L2Wrap(torch.autograd.Function):
     This version is memory-optimized by not storing the full logits tensor.
     """
     @staticmethod
-    def forward(ctx, loss, logits):
-        ctx.save_for_backward(logits)
+    def forward(ctx, loss, logits, l2_penalty_factor=1e-4):
+        """
+        Forward pass for the L2 penalty.
+        Args:
+            loss (torch.Tensor): The loss tensor.
+            logits (torch.Tensor): The logits tensor of shape [B, T, V].
+            l2_penalty_factor (float): The factor for the L2 penalty.
+        """
+        maxx, ids = torch.max(logits, dim=-1, keepdim=True)
+        ctx.logits_shape = logits.shape
+        factor = l2_penalty_factor / (logits.shape[0] * logits.shape[1])
+        maxx = maxx * factor
+        ctx.save_for_backward(maxx, ids)
         return loss
 
     @staticmethod
     def backward(ctx, grad_output):
-        logits = ctx.saved_tensors[0]
-
-        factor = 1e-4 / (logits.shape[0] * logits.shape[1])
-        maxx, ids = torch.max(logits, -1, keepdim=True)
-
-        glogits = torch.zeros_like(logits)
-        penalty_grad = maxx * factor
-        glogits.scatter_(-1, ids, penalty_grad)
-
-        return grad_output, glogits
+        maxx, ids = ctx.saved_tensors
+        glogits = torch.zeros(ctx.logits_shape, device=grad_output.device,
+                              dtype=grad_output.dtype)
+        glogits.scatter_(-1, ids, maxx)
+        return grad_output, glogits, None
 
 
 l2_warp = L2Wrap.apply
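
In short: forward now computes the max logits and their indices once, scales the max values by l2_penalty_factor / (B * T), and saves only those two small tensors plus the logits shape, so backward can rebuild the sparse penalty gradient without keeping the full [B, T, V] logits alive. Below is a minimal sketch (not part of the commit; plain PyTorch, default factor from the diff) showing that the old and new schemes produce the same extra gradient:

    import torch

    def old_glogits(logits):
        # Old backward: recompute everything from the saved full logits tensor.
        factor = 1e-4 / (logits.shape[0] * logits.shape[1])
        maxx, ids = torch.max(logits, -1, keepdim=True)
        glogits = torch.zeros_like(logits)
        glogits.scatter_(-1, ids, maxx * factor)
        return glogits

    def new_glogits(logits, l2_penalty_factor=1e-4):
        # New scheme: pre-scale the max values in forward, save only (maxx, ids, shape).
        maxx, ids = torch.max(logits, dim=-1, keepdim=True)
        maxx = maxx * (l2_penalty_factor / (logits.shape[0] * logits.shape[1]))
        glogits = torch.zeros(logits.shape, dtype=logits.dtype)
        glogits.scatter_(-1, ids, maxx)
        return glogits

    logits = torch.randn(2, 16, 50)
    assert torch.allclose(old_glogits(logits), new_glogits(logits))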

tests/modules/test_l2warp.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 @pytest.mark.parametrize("T", [1024])
 @pytest.mark.parametrize("H", [256])
 @pytest.mark.parametrize("V", [2000])
-@pytest.mark.parametrize("l2_penalty_factor", [1e-4])
+@pytest.mark.parametrize("l2_penalty_factor", [1e-4, 1])
 @pytest.mark.skipif(
     is_intel_alchemist is True,
     reason="Intel Triton Failure"
@@ -41,7 +41,7 @@ def test_fused_linear_cross_entropy_l2_warp(
 
     ref_logits = F.linear(x.view(-1, H), lm_head.weight, lm_head.bias)
     ref_loss_ce = ref_criterion(ref_logits.view(B * T, V), shift_labels.view(-1))
-    ref_loss = standalone_l2_warp(ref_loss_ce, ref_logits.view(B, T, V))
+    ref_loss = standalone_l2_warp(ref_loss_ce, ref_logits.view(B, T, V), l2_penalty_factor)
 
     ref_loss.backward()
     ref_x_grad = x.grad.clone()
@@ -63,7 +63,7 @@ def test_fused_linear_cross_entropy_l2_warp(
     fused_w_grad = lm_head.weight.grad.clone()
     fused_b_grad = lm_head.bias.grad.clone()
 
-    ratio = 4e-3 if dtype == torch.bfloat16 else 1e-5
+    ratio = 4e-3 if dtype == torch.bfloat16 else 1e-3
 
     assert_close("Loss", ref_loss, fused_loss, ratio)
     assert_close("dx", ref_x_grad, fused_x_grad, ratio)
