11# SPDX-License-Identifier: MIT
2- # Copyright (C) 2024-2025 , Advanced Micro Devices, Inc. All rights reserved.
2+ # Copyright (C) 2024-2026 , Advanced Micro Devices, Inc. All rights reserved.
33
44import torch
55from torch import Tensor
@@ -59,16 +59,20 @@ def rms_norm(
5959 ...
6060
6161
def rmsnorm2d_fwd(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> Tensor:
    """Run the RMSNorm forward op on ``input`` and return the result tensor.

    Two backends are selected between:

    * ``rmsnorm2d_fwd_ck`` when ``use_model_sensitive_rmsnorm`` is positive
      or the last dimension of ``input`` is larger than 8192;
    * ``rmsnorm`` otherwise, writing into a freshly allocated tensor that
      has the same shape, dtype and device as ``input``.

    Args:
        input: Tensor to normalize; only ``input.shape[-1]`` is inspected
            here.
        weight: Forwarded unchanged to the selected backend.
        epsilon: Forwarded unchanged to the selected backend.
        use_model_sensitive_rmsnorm: Forwarded to ``rmsnorm2d_fwd_ck``; any
            value greater than 0 forces that backend.

    Returns:
        The tensor produced by the selected backend.
    """
    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
        # rmsnorm2d_fwd_ck is declared as
        # (input, weight, epsilon, use_model_sensitive_rmsnorm) -> Tensor:
        # it takes no ``out`` argument and hands back its own result, so
        # call it with exactly those arguments and return what it produces.
        return rmsnorm2d_fwd_ck(input, weight, epsilon, use_model_sensitive_rmsnorm)

    # The ``rmsnorm`` binding writes into a caller-provided output tensor.
    out = torch.empty_like(input, dtype=input.dtype, device=input.device)
    rmsnorm(out, input, weight, epsilon)
    return out
6974
7075
71- @compile_ops ("module_rmsnorm" )
7276def rmsnorm2d_fwd_with_add (
7377 out : Tensor ,
7478 input : Tensor ,
@@ -77,7 +81,19 @@ def rmsnorm2d_fwd_with_add(
7781 weight : Tensor ,
7882 epsilon : float ,
7983 use_model_sensitive_rmsnorm : int = 0 ,
80- ) -> None : ...
84+ ) -> None :
85+ if use_model_sensitive_rmsnorm > 0 or input .shape [- 1 ] > 8192 :
86+ rmsnorm2d_fwd_with_add_ck (
87+ out ,
88+ input ,
89+ residual_in ,
90+ residual_out ,
91+ weight ,
92+ epsilon ,
93+ use_model_sensitive_rmsnorm ,
94+ )
95+ else :
96+ add_rmsnorm (out , input , residual_in , residual_out , weight , epsilon )
8197
8298
8399@compile_ops ("module_rmsnorm" )
@@ -107,18 +123,26 @@ def rmsnorm2d_fwd_with_add_smoothquant(
107123) -> None : ...
108124
109125
def rmsnorm2d_fwd_with_dynamicquant(
    out: Tensor,
    input: Tensor,
    yscale: Tensor,
    weight: Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
    group_size: int = 0,
    shuffle_scale: bool = False,
) -> None:
    """Dispatch the RMSNorm + dynamic-quant forward op to one of two backends.

    ``rmsnorm_quant`` is used unless ``use_model_sensitive_rmsnorm`` is
    positive or ``input.shape[-1]`` exceeds 8192; in that case the call goes
    to ``rmsnorm2d_fwd_with_dynamicquant_ck``, which accepts neither
    ``group_size`` nor ``shuffle_scale`` (both must keep their defaults).

    Nothing is returned; ``out`` and ``yscale`` are passed to the backend,
    which is presumably where results are written (confirm against the
    bindings).

    Raises:
        AssertionError: if the ``_ck`` backend is selected while
            ``group_size != 0`` or ``shuffle_scale`` is true.
    """
    wants_ck_backend = use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192

    if not wants_ck_backend:
        rmsnorm_quant(out, input, yscale, weight, epsilon, group_size, shuffle_scale)
        return

    # The _ck binding has no group_size / shuffle_scale parameters.
    assert group_size == 0, "group_size is not supported for ck rmsnorm"
    assert not shuffle_scale, "shuffle_scale is not supported for ck rmsnorm"
    rmsnorm2d_fwd_with_dynamicquant_ck(
        out,
        input,
        yscale,
        weight,
        epsilon,
        use_model_sensitive_rmsnorm,
    )
119144
120145
121- @compile_ops ("module_rmsnorm" )
122146def rmsnorm2d_fwd_with_add_dynamicquant (
123147 out : Tensor ,
124148 input : Tensor ,
@@ -128,4 +152,124 @@ def rmsnorm2d_fwd_with_add_dynamicquant(
128152 weight : Tensor ,
129153 epsilon : float ,
130154 use_model_sensitive_rmsnorm : int = 0 ,
155+ group_size : int = 0 ,
156+ shuffle_scale : bool = False ,
157+ ) -> None :
158+ if use_model_sensitive_rmsnorm > 0 or input .shape [- 1 ] > 8192 :
159+ assert group_size == 0 , "group_size is not supported for ck rmsnorm"
160+ assert not shuffle_scale , "shuffle_scale is not supported for ck rmsnorm"
161+ rmsnorm2d_fwd_with_add_dynamicquant_ck (
162+ out ,
163+ input ,
164+ residual_in ,
165+ residual_out ,
166+ yscale ,
167+ weight ,
168+ epsilon ,
169+ use_model_sensitive_rmsnorm ,
170+ )
171+ else :
172+ add_rmsnorm_quant (
173+ out ,
174+ input ,
175+ residual_in ,
176+ residual_out ,
177+ yscale ,
178+ weight ,
179+ epsilon ,
180+ group_size ,
181+ shuffle_scale ,
182+ )
183+
184+
@compile_ops(
    "module_rmsnorm",
    gen_fake=gen_rms_norm_fake_tensor,
    fc_name="rmsnorm2d_fwd",
)
def rmsnorm2d_fwd_ck(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> Tensor:
    """Declaration of the ``rmsnorm2d_fwd`` op from ``module_rmsnorm``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation registered under ``fc_name`` (confirm
    against the decorator). Per the annotation the op returns a Tensor.
    """
    ...
194+
195+
@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd_with_add")
def rmsnorm2d_fwd_with_add_ck(
    out: Tensor,
    input: Tensor,
    residual_in: Tensor,
    residual_out: Tensor,
    weight: Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> None:
    """Declaration of the ``rmsnorm2d_fwd_with_add`` op from ``module_rmsnorm``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation registered under ``fc_name`` (confirm
    against the decorator). Annotated as returning ``None``.
    """
    ...
206+
207+
@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd_with_dynamicquant")
def rmsnorm2d_fwd_with_dynamicquant_ck(
    out: Tensor,
    input: Tensor,
    yscale: Tensor,
    weight: Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> None:
    """Declaration of the ``rmsnorm2d_fwd_with_dynamicquant`` op from ``module_rmsnorm``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation registered under ``fc_name`` (confirm
    against the decorator). Unlike ``rmsnorm_quant`` this signature has no
    ``group_size`` or ``shuffle_scale`` parameters.
    """
    ...
217+
218+
@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd_with_add_dynamicquant")
def rmsnorm2d_fwd_with_add_dynamicquant_ck(
    out: Tensor,
    input: Tensor,
    residual_in: Tensor,
    residual_out: Tensor,
    yscale: Tensor,
    weight: Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> None:
    """Declaration of the ``rmsnorm2d_fwd_with_add_dynamicquant`` op from ``module_rmsnorm``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation registered under ``fc_name`` (confirm
    against the decorator). Unlike ``add_rmsnorm_quant`` this signature has
    no ``group_size`` or ``shuffle_scale`` parameters.
    """
    ...
230+
231+
@compile_ops("module_rmsnorm_quant")
def add_rmsnorm_quant(
    out: Tensor,
    input: Tensor,
    residual_in: Tensor,
    residual_out: Tensor,
    scale: Tensor,
    weight: Tensor,
    epsilon: float,
    group_size: int = 0,
    shuffle_scale: bool = False,
) -> None:
    """Declaration of the ``add_rmsnorm_quant`` op from ``module_rmsnorm_quant``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation (confirm against the decorator). Annotated
    as returning ``None``.
    """
    ...
244+
245+
@compile_ops("module_rmsnorm_quant")
def add_rmsnorm(
    out: Tensor,
    input: Tensor,
    residual_in: Tensor,
    residual_out: Tensor,
    weight: Tensor,
    epsilon: float,
) -> None:
    """Declaration of the ``add_rmsnorm`` op from ``module_rmsnorm_quant``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation (confirm against the decorator). Annotated
    as returning ``None``.
    """
    ...
255+
256+
@compile_ops("module_rmsnorm_quant")
def rmsnorm_quant(
    out: Tensor,
    input: Tensor,
    scale: Tensor,
    weight: Tensor,
    epsilon: float,
    group_size: int = 0,
    shuffle_scale: bool = False,
) -> None:
    """Declaration of the ``rmsnorm_quant`` op from ``module_rmsnorm_quant``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation (confirm against the decorator). Annotated
    as returning ``None``.
    """
    ...
267+
268+
@compile_ops("module_rmsnorm_quant")
def rmsnorm(
    out: Tensor,
    input: Tensor,
    weight: Tensor,
    epsilon: float,
) -> None:
    """Declaration of the ``rmsnorm`` op from ``module_rmsnorm_quant``.

    The Python body is intentionally empty; presumably ``compile_ops``
    supplies the implementation (confirm against the decorator). Annotated
    as returning ``None``; it takes the output tensor ``out`` as its first
    argument.
    """
    ...
0 commit comments