@@ -352,3 +352,229 @@ def test_amp(B, T, H, V, bias, cast_dtype, accum_dtype, atol, rtol):
         atol=atol,
         rtol=rtol,
     )
+
+
+def test_correctness_token_scaling():
+    """Test that token scaling produces the correct loss values and gradients."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype, requires_grad=True)
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test using functional API with token scaling
+    loss_scaled = liger_fused_linear_cross_entropy(
+        input=_input,
+        weight=weight,
+        target=target,
+        bias=bias,
+        ignore_index=-100,
+        reduction="none",  # Use "none" to get per-token losses
+        use_token_scaling=True,
+    )
+
+    # Compare with manual implementation
+    # Compute logits
+    logits = _input @ weight.t()
+    if bias is not None:
+        logits = logits + bias
+
+    # Compute standard cross entropy loss per token
+    ce_loss = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
+
+    # Compute predicted probabilities for target tokens
+    pred_probs = torch.softmax(logits, dim=-1).gather(1, target.unsqueeze(-1)).squeeze(-1).detach()
+
+    # Scale by predicted probabilities
+    expected_loss = ce_loss * pred_probs
+
+    # Check that losses are close
+    assert torch.allclose(loss_scaled, expected_loss, atol=1e-4, rtol=1e-4)
+
+    # Test gradients
+    loss_scaled.sum().backward(retain_graph=True)
+    grad_scaled = _input.grad.clone()
+    _input.grad.zero_()
+
+    expected_loss.sum().backward(retain_graph=True)
+    grad_expected = _input.grad.clone()
+    _input.grad.zero_()
+
+    # Check that gradients are close
+    assert torch.allclose(grad_scaled, grad_expected, atol=1e-4, rtol=1e-4)
+
+
+def test_correctness_token_scaling_consistency():
+    """Test that token scaling is consistent between functional and module APIs."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype, requires_grad=True)
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test functional API
+    loss_functional = liger_fused_linear_cross_entropy(
+        input=_input,
+        weight=weight,
+        target=target,
+        bias=bias,
+        ignore_index=-100,
+        reduction="sum",
+        use_token_scaling=True,
+    )
+
+    # Test module API
+    ce_loss_module = LigerFusedLinearCrossEntropyLoss(
+        ignore_index=-100,
+        reduction="sum",
+        use_token_scaling=True,
+    )
+
+    loss_module = ce_loss_module(weight, _input, target, bias)
+
+    # Check that losses are identical
+    assert torch.allclose(loss_functional, loss_module, atol=1e-6, rtol=1e-6)
+
+    # Test gradients
+    loss_functional.backward(retain_graph=True)
+    grad_functional = _input.grad.clone()
+    _input.grad.zero_()
+
+    loss_module.backward(retain_graph=True)
+    grad_module = _input.grad.clone()
+    _input.grad.zero_()
+
+    # Check that gradients are identical
+    assert torch.allclose(grad_functional, grad_module, atol=1e-6, rtol=1e-6)
+
+
+def test_correctness_token_scaling_functional():
+    """Test token scaling using the functional API."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype)
+    x1 = _input.detach().clone().requires_grad_(True)
+    x2 = _input.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test using functional API with token scaling
+    y1 = liger_fused_linear_cross_entropy(
+        input=x1,
+        weight=weight,
+        target=target,
+        bias=bias,
+        ignore_index=-100,
+        lse_square_scale=0.0,
+        label_smoothing=0.0,
+        reduction="sum",  # Use sum for easier verification
+        softcap=None,
+        return_z_loss=False,
+        accum_dtype=None,
+        use_token_scaling=True,
+    )
+
+    # Compare with manual implementation
+    # Compute logits
+    logits = x2 @ weight.t()
+    if bias is not None:
+        logits = logits + bias
+
+    # Compute softmax probabilities
+    probs = torch.softmax(logits.detach(), dim=-1)  # Detach to avoid gradient flow
+
+    # Get predicted probabilities for target tokens
+    pred_probs = torch.gather(probs, -1, target.unsqueeze(-1)).squeeze(-1)
+
+    # Compute standard cross entropy loss
+    ce_loss = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
+
+    # Scale by predicted probabilities
+    scaled_loss = ce_loss * pred_probs
+
+    # Sum over all tokens
+    y2 = scaled_loss.sum()
+
+    # Check that losses are close
+    assert torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)
+
+    # Test gradients
+    y1.backward()
+    y2.backward()
+
+    # Check that gradients are close
+    assert torch.allclose(x1.grad, x2.grad, atol=1e-5, rtol=1e-5)
+
+
+def test_correctness_token_scaling_module():
+    """Test token scaling using the module API."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype)
+    x1 = _input.detach().clone().requires_grad_(True)
+    x2 = _input.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create module with token scaling
+    ce_loss = LigerFusedLinearCrossEntropyLoss(
+        ignore_index=-100,
+        reduction="sum",
+        use_token_scaling=True,
+    )
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test using module API with token scaling
+    y1 = ce_loss(weight, x1, target, bias)
+
+    # Compare with manual implementation
+    # Compute logits
+    logits = x2 @ weight.t()
+    if bias is not None:
+        logits = logits + bias
+
+    # Compute softmax probabilities
+    probs = torch.softmax(logits.detach(), dim=-1)  # Detach to avoid gradient flow
+
+    # Get predicted probabilities for target tokens
+    pred_probs = torch.gather(probs, -1, target.unsqueeze(-1)).squeeze(-1)
+
+    # Compute standard cross entropy loss
+    ce_loss_manual = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
+
+    # Scale by predicted probabilities
+    scaled_loss = ce_loss_manual * pred_probs
+
+    # Sum over all tokens
+    y2 = scaled_loss.sum()
+
+    # Check that losses are close
+    assert torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)
+
+    # Test gradients
+    y1.backward()
+    y2.backward()
+
+    # Check that gradients are close
+    assert torch.allclose(x1.grad, x2.grad, atol=1e-5, rtol=1e-5)
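
For reference, the eager-PyTorch baselines used throughout these tests reduce to one computation: each token's cross-entropy loss is multiplied by the softmax probability the model assigns to that token's target, with the probability detached so the scale factor itself contributes no gradient. Below is a minimal standalone sketch of that baseline, assuming the (B*T, H) input, (V, H) weight, and optional (V,) bias layout used above; the helper name is illustrative and not part of the library.

import torch
import torch.nn.functional as F


def token_scaled_ce_reference(x, weight, target, bias=None, ignore_index=-100):
    # Project hidden states to vocabulary logits: (B*T, V)
    logits = x @ weight.t()
    if bias is not None:
        logits = logits + bias
    # Per-token cross entropy, kept unreduced
    ce = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="none")
    # Probability assigned to each target token; detached so only ce carries gradient
    p_target = torch.softmax(logits, dim=-1).gather(1, target.unsqueeze(-1)).squeeze(-1).detach()
    # Scaled per-token losses; apply .sum() or .mean() to match the reduction under test
    return ce * p_target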