 import torch.nn.functional as F
 
 from .fast_norm import (
-    is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d,
-    fast_simple_norm, simple_norm
+    is_fast_norm,
+    fast_group_norm,
+    fast_layer_norm,
+    fast_rms_norm,
+    rms_norm2d,
+    fast_rms_norm2d,
+    fast_simple_norm,
+    simple_norm,
 )
 
 try:
@@ -24,15 +30,27 @@
 
 class GroupNorm(nn.GroupNorm):
     _fast_norm: torch.jit.Final[bool]
-
-    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
+    _fp32_norm: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            num_channels: int,
+            num_groups: int = 32,
+            eps: float = 1e-5,
+            affine: bool = True,
+            use_fp32: bool = False,
+            **kwargs,
+    ):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
-        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        super().__init__(num_groups, num_channels, eps=eps, affine=affine, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
     def forward(self, x):
         if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
@@ -42,14 +60,18 @@ class GroupNorm1(nn.GroupNorm):
     Input: tensor in shape [B, C, *]
     """
     _fast_norm: torch.jit.Final[bool]
+    _fp32_norm: torch.jit.Final[bool]
 
-    def __init__(self, num_channels, **kwargs):
+    def __init__(self, num_channels: int, use_fp32: bool = False, **kwargs):
         super().__init__(1, num_channels, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
@@ -58,14 +80,25 @@ class LayerNorm(nn.LayerNorm):
5880 """ LayerNorm w/ fast norm option
5981 """
6082 _fast_norm : torch .jit .Final [bool ]
61-
62- def __init__ (self , num_channels , eps = 1e-6 , affine = True ):
63- super ().__init__ (num_channels , eps = eps , elementwise_affine = affine )
83+ _fp32_norm : torch .jit .Final [bool ]
84+
85+ def __init__ (
86+ self ,
87+ num_channels : int ,
88+ eps : float = 1e-6 ,
89+ affine : bool = True ,
90+ use_fp32 : bool = False ,
91+ ** kwargs ,
92+ ):
93+ super ().__init__ (num_channels , eps = eps , elementwise_affine = affine , ** kwargs )
6494 self ._fast_norm = is_fast_norm () # can't script unless we have these flags here (no globals)
95+ self ._fp32_norm = use_fp32
6596
6697 def forward (self , x : torch .Tensor ) -> torch .Tensor :
6798 if self ._fast_norm :
6899 x = fast_layer_norm (x , self .normalized_shape , self .weight , self .bias , self .eps )
100+ elif self ._fp32_norm :
101+ x = F .layer_norm (x .float (), self .normalized_shape , self .weight , self .bias , self .eps ).to (x .dtype )
69102 else :
70103 x = F .layer_norm (x , self .normalized_shape , self .weight , self .bias , self .eps )
71104 return x
@@ -74,15 +107,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
     _fast_norm: torch.jit.Final[bool]
-
-    def __init__(self, num_channels, eps=1e-6, affine=True):
-        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+    _fp32_norm: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            num_channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            **kwargs,
+    ):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         if self._fast_norm:
             x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = x.permute(0, 3, 1, 2)
@@ -121,7 +165,7 @@ class LayerNormExp2d(nn.LayerNorm):
     layout. However, benefits are not always clear and can perform worse on other GPUs.
     """
 
-    def __init__(self, num_channels, eps=1e-6):
+    def __init__(self, num_channels: int, eps: float = 1e-6):
         super().__init__(num_channels, eps=eps)
 
     def forward(self, x) -> torch.Tensor:
@@ -136,13 +180,22 @@ def forward(self, x) -> torch.Tensor:
 class RmsNorm(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -153,6 +206,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -167,9 +221,11 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
-        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed.
         if self._fast_norm:
             x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
@@ -182,13 +238,22 @@ class RmsNorm2d(nn.Module):
     on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
     like https://github.com/pytorch/pytorch/pull/150576 lands.
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -199,6 +264,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -216,6 +282,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
         if self._fast_norm:
             x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
         return x
@@ -224,13 +292,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SimpleNorm(nn.Module):
     """ SimpleNorm (x / std(x))
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -241,6 +318,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -256,6 +334,8 @@ def reset_parameters(self) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
             x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
@@ -264,13 +344,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SimpleNorm2d(nn.Module):
     """ SimpleNorm for NCHW tensors
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -281,6 +370,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32
 
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -297,6 +387,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         if self._fast_norm:
             x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
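
The diff above threads one new `use_fp32` flag through each norm class: when set, and when the fast/APEX norm path is not active, the layer upcasts its input to float32 for the normalization itself and casts the result back to the input dtype. Below is a minimal usage sketch, not part of the commit; it assumes these classes are importable from `timm.layers` as in current timm and that no fast-norm path is enabled, so the fp32 branch is actually taken.

import torch
from timm.layers import LayerNorm2d, RmsNorm  # assumed import location, adjust to your checkout

# Layer parameters stay in fp32 (the default); the activation is bf16.
norm = LayerNorm2d(64, use_fp32=True)
x = torch.randn(2, 64, 14, 14).to(torch.bfloat16)

# With use_fp32=True the forward runs roughly as
#   F.layer_norm(x.float(), ...).to(x.dtype)
# so the reduction happens in fp32 while the output dtype still matches the input.
y = norm(x)
assert y.dtype == torch.bfloat16

# Same pattern for RmsNorm, which normalizes over the last dimension.
rms = RmsNorm(256, use_fp32=True)
assert rms(torch.randn(8, 196, 256).to(torch.bfloat16)).dtype == torch.bfloat16

Note that in every `forward` the `self._fast_norm` branch takes precedence over `self._fp32_norm`, so enabling the fast/APEX kernels bypasses the forced-fp32 path.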