Skip to content

Commit 6a2c491

Browse files
Manan17Manan Shahvaibhavjindallancerts
authored
Layernorm enhancement (#815)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> This is an enhancement to the layer_norm kernel. Majorly the backward kernel. Changes in forward: 1. The new version performs all computations in float32 regardless of input dtype, then converts back only once at the end. This prevents precision loss that can accumulate with float16/bfloat16. 2. The new version pre-calculates row pointers instead of modifying the base pointers. This is cleaner and potentially more efficient for the compiler. 3. Eliminates the tl.where masking operation (relies on loading with other=0.0) Changes in backward: 1. The original processes multiple rows per thread block in a loop, while the optimized version processes exactly one row per thread block. 2. The original uses local accumulation then stores to separate locations, while the optimized uses atomic `tl.atomic_add` operations to accumulate directly to the final gradient tensors. 3. The original requires a separate reduction step in Python, while the optimized kernel produces final gradients directly. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> Here are the benchmark results for speed: Before: Backward pass- <img width="1000" height="600" alt="layer_norm_speed_backward" src="https://github.com/user-attachments/assets/12b38b8f-f27f-4c48-ac8e-0a2f462f4372" /> After: Backward pass - <img width="1000" height="600" alt="layer_norm_speed_backward_new" src="https://github.com/user-attachments/assets/23dd7fe2-4e1e-4698-b4d6-da4db30a8544" /> Before: Full - <img width="1000" height="600" alt="layer_norm_speed_full" src="https://github.com/user-attachments/assets/1422a070-678f-4437-a5e5-37db47dd3171" /> After: Full- <img width="1000" height="600" alt="layer_norm_speed_full_new" src="https://github.com/user-attachments/assets/a702318e-b76a-4cbf-9a73-52a7637b63f8" /> This layer norm is 27% faster than Huggingface's implementation. The benchmarks for forwards pass (speed) and full pass (memory) are the same. ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> Ran: python -m pytest test/transformers/test_layer_norm.py <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: NVIDIA H100 80GB HBM3 - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Manan Shah <[email protected]> Co-authored-by: Vaibhav Jindal <[email protected]> Co-authored-by: Shao Tang <[email protected]>
1 parent 1965405 commit 6a2c491

File tree

3 files changed

+170
-119
lines changed

3 files changed

+170
-119
lines changed

benchmark/data/all_benchmark_data.csv

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -625,36 +625,6 @@ group_norm,huggingface,backward,memory,MB,C,num_channels,256,320.5078125,320.507
625625
group_norm,huggingface,backward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1
626626
group_norm,huggingface,backward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1
627627
group_norm,huggingface,backward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1
628-
layer_norm,liger,forward,speed,ms,N,hidden size,1024,0.035840000957250595,0.03481600061058998,0.035840000957250595,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1
629-
layer_norm,liger,forward,speed,ms,N,hidden size,2048,0.05939200147986412,0.058368001133203506,0.060416001826524734,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1
630-
layer_norm,liger,forward,speed,ms,N,hidden size,4096,0.10751999914646149,0.10751999914646149,0.1085439994931221,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1
631-
layer_norm,liger,forward,speed,ms,N,hidden size,8192,0.20582400262355804,0.20479999482631683,0.20684799551963806,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1
632-
layer_norm,liger,forward,speed,ms,N,hidden size,16384,0.3993600010871887,0.3983359932899475,0.40140798687934875,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1
633-
layer_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.03788800165057182,0.03788800165057182,0.03891199827194214,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1
634-
layer_norm,huggingface,forward,speed,ms,N,hidden size,2048,0.0655359998345375,0.0655359998345375,0.06656000018119812,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1
635-
layer_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.14745600521564484,0.14643199741840363,0.14847999811172485,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1
636-
layer_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.31334400177001953,0.3123199939727783,0.31436800956726074,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1
637-
layer_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.6133760213851929,0.6123520135879517,0.6154239773750305,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1
638-
layer_norm,liger,full,speed,ms,N,hidden size,1024,0.6860799789428711,0.6146048903465271,0.7049216032028198,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1
639-
layer_norm,liger,full,speed,ms,N,hidden size,2048,0.6789119839668274,0.6737920045852661,0.6912000179290771,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1
640-
layer_norm,liger,full,speed,ms,N,hidden size,4096,0.6686720252037048,0.6635519862174988,0.681984007358551,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1
641-
layer_norm,liger,full,speed,ms,N,hidden size,8192,0.6789119839668274,0.5908480286598206,0.6932479739189148,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1
642-
layer_norm,liger,full,speed,ms,N,hidden size,16384,6.071296215057373,5.331148624420166,6.08235502243042,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1
643-
layer_norm,huggingface,full,speed,ms,N,hidden size,1024,0.13312000036239624,0.13209599256515503,0.13312000036239624,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
644-
layer_norm,huggingface,full,speed,ms,N,hidden size,2048,0.23244799673557281,0.2303999960422516,0.23347200453281403,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
645-
layer_norm,huggingface,full,speed,ms,N,hidden size,4096,0.5242879986763,0.5232639908790588,0.5263360142707825,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
646-
layer_norm,huggingface,full,speed,ms,N,hidden size,8192,1.0168319940567017,1.0147839784622192,1.018880009651184,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
647-
layer_norm,huggingface,full,speed,ms,N,hidden size,16384,1.994752049446106,1.9916800260543823,1.9967999458312988,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
648-
layer_norm,liger,full,memory,MB,N,hidden size,1024,80.90625,80.90625,80.90625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
649-
layer_norm,liger,full,memory,MB,N,hidden size,2048,161.78125,161.78125,161.78125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
650-
layer_norm,liger,full,memory,MB,N,hidden size,4096,323.53125,323.53125,323.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
651-
layer_norm,liger,full,memory,MB,N,hidden size,8192,647.03125,647.03125,647.03125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
652-
layer_norm,liger,full,memory,MB,N,hidden size,16384,1294.03125,1294.03125,1294.03125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1
653-
layer_norm,huggingface,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
654-
layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
655-
layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
656-
layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
657-
layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
658628
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,116.00621032714844,116.00621032714844,116.00621032714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
659629
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,230.83609008789062,230.83609008789062,230.83609008789062,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
660630
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,461.9543151855469,461.9543151855469,461.9543151855469,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
@@ -1493,6 +1463,46 @@ distill_cosine_loss,torch,full,memory,MB,BT,B x T,1024,7566.2822265625,7566.2822
14931463
distill_cosine_loss,torch,full,memory,MB,BT,B x T,2048,11590.3134765625,11590.3134765625,11590.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
14941464
distill_cosine_loss,torch,full,memory,MB,BT,B x T,4096,19654.375,19654.375,19654.375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
14951465
distill_cosine_loss,torch,full,memory,MB,BT,B x T,8192,35782.5,35782.5,35782.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
1466+
layer_norm,liger,forward,speed,ms,N,hidden size,1024,0.018848000094294548,0.018400000408291817,0.020102400332689285,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0
1467+
layer_norm,liger,forward,speed,ms,N,hidden size,2048,0.029152000322937965,0.02876799926161766,0.029823999851942062,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0
1468+
layer_norm,liger,forward,speed,ms,N,hidden size,4096,0.05104000121355057,0.05036799982190132,0.05177599936723709,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0
1469+
layer_norm,liger,forward,speed,ms,N,hidden size,8192,0.0947519987821579,0.09436800330877304,0.09507200121879578,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0
1470+
layer_norm,liger,forward,speed,ms,N,hidden size,16384,0.18476800620555878,0.18396799266338348,0.1852159947156906,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0
1471+
layer_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.023584000766277313,0.023423999547958374,0.023840000852942467,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0
1472+
layer_norm,huggingface,forward,speed,ms,N,hidden size,2048,0.03734400123357773,0.03702399879693985,0.037811201065778746,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0
1473+
layer_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.06617599725723267,0.06560000032186508,0.06678400188684464,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0
1474+
layer_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.15267199277877808,0.15190400183200836,0.15347200632095337,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0
1475+
layer_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.3067840039730072,0.3046143889427185,0.3081152021884918,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0
1476+
layer_norm,liger,backward,speed,ms,N,hidden size,1024,0.12006399780511856,0.11653760075569153,0.12467200309038162,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0
1477+
layer_norm,liger,backward,speed,ms,N,hidden size,2048,0.1207360029220581,0.1176128014922142,0.1256511986255646,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0
1478+
layer_norm,liger,backward,speed,ms,N,hidden size,4096,0.16630400717258453,0.16412800550460815,0.16838400065898895,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0
1479+
layer_norm,liger,backward,speed,ms,N,hidden size,8192,0.31279999017715454,0.31116798520088196,0.3145279884338379,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0
1480+
layer_norm,liger,backward,speed,ms,N,hidden size,16384,0.5776320099830627,0.5753471970558167,0.5798912048339844,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0
1481+
layer_norm,huggingface,backward,speed,ms,N,hidden size,1024,0.0605119988322258,0.059647999703884125,0.061344001442193985,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0
1482+
layer_norm,huggingface,backward,speed,ms,N,hidden size,2048,0.09967999905347824,0.09849599748849869,0.10099200159311295,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0
1483+
layer_norm,huggingface,backward,speed,ms,N,hidden size,4096,0.17881600558757782,0.17795200645923615,0.17971199750900269,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0
1484+
layer_norm,huggingface,backward,speed,ms,N,hidden size,8192,0.33369600772857666,0.3328000009059906,0.33478400111198425,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0
1485+
layer_norm,huggingface,backward,speed,ms,N,hidden size,16384,0.6424000263214111,0.6412223815917969,0.643455982208252,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0
1486+
layer_norm,liger,full,speed,ms,N,hidden size,1024,0.26576000452041626,0.2629248082637787,0.2701759934425354,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0
1487+
layer_norm,liger,full,speed,ms,N,hidden size,2048,0.27427199482917786,0.26999040842056277,0.28091518878936766,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0
1488+
layer_norm,liger,full,speed,ms,N,hidden size,4096,0.27454400062561035,0.27004799246788025,0.2807359993457794,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0
1489+
layer_norm,liger,full,speed,ms,N,hidden size,8192,0.40556800365448,0.40403199195861816,0.40723198652267456,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0
1490+
layer_norm,liger,full,speed,ms,N,hidden size,16384,0.7608960270881653,0.7589311957359314,0.7631679773330688,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0
1491+
layer_norm,huggingface,full,speed,ms,N,hidden size,1024,0.08025600016117096,0.07942400127649307,0.08111999928951263,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1492+
layer_norm,huggingface,full,speed,ms,N,hidden size,2048,0.13315199315547943,0.13180799782276154,0.13468800485134125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1493+
layer_norm,huggingface,full,speed,ms,N,hidden size,4096,0.2417600005865097,0.24089600145816803,0.24262399971485138,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1494+
layer_norm,huggingface,full,speed,ms,N,hidden size,8192,0.4832639992237091,0.48214399814605713,0.4843647956848145,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1495+
layer_norm,huggingface,full,speed,ms,N,hidden size,16384,0.950575977563858,0.9484800100326538,0.9528064012527466,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1496+
layer_norm,liger,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1497+
layer_norm,liger,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1498+
layer_norm,liger,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1499+
layer_norm,liger,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1500+
layer_norm,liger,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1501+
layer_norm,huggingface,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1502+
layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1503+
layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1504+
layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
1505+
layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0
14961506
fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,1024,0.01759999990463257,0.017311999574303627,0.017920000478625298,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0
14971507
fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,2048,0.02924799919128418,0.028863999992609024,0.029983999207615852,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0
14981508
fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,4096,0.05129599943757057,0.050624001771211624,0.05209600180387497,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0
@@ -1564,4 +1574,4 @@ fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,2048,208.06298828
15641574
fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,4096,416.11767578125,416.11767578125,416.11767578125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
15651575
fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,8192,832.22705078125,832.22705078125,832.22705078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
15661576
fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,16384,1544.44580078125,1544.44580078125,1544.44580078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
1567-
fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,32768,2960.8837890625,2960.8837890625,2960.8837890625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
1577+
fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,32768,2960.8837890625,2960.8837890625,2960.8837890625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0

0 commit comments

Comments
 (0)