
Commit c08db89

Add flag to enable float32 computation for normalization (norm + affine) in several timm norm & norm+act layers
1 parent 6239313 commit c08db89

3 files changed: +250 -101 lines
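The new use_fp32 flag added by this commit runs each layer's normalization and affine step in float32 and casts the result back to the input dtype (every new branch ends in .to(x.dtype)). A minimal usage sketch, not part of the commit, assuming the patched classes are importable from timm.layers as in current releases:

import torch
from timm.layers import LayerNorm  # timm's LayerNorm wrapper, shown in the diff below

norm = LayerNorm(256, use_fp32=True)                 # norm + affine computed in float32
x = torch.randn(2, 196, 256, dtype=torch.bfloat16)
y = norm(x)                                          # upcast internally, cast back on return
assert y.dtype == torch.bfloat16                     # output dtype still matches the input

Without the flag the layers behave exactly as before; the fast-norm path, when enabled, still takes precedence over the fp32 path.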

timm/layers/create_act.py
9 additions & 4 deletions
@@ -1,11 +1,12 @@
 """ Activation Factory
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from typing import Union, Callable, Type
+from typing import Callable, Optional, Type, Union

 from .activations import *
 from .activations_me import *
 from .config import is_exportable, is_scriptable
+from .typing import LayerType

 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -88,7 +89,7 @@
     a.setdefault('hardswish', a.get('hard_swish'))


-def get_act_fn(name: Union[Callable, str] = 'relu'):
+def get_act_fn(name: Optional[LayerType] = 'relu'):
     """ Activation Function Factory
     Fetching activation fns by name with this function allows export or torch script friendly
     functions to be returned dynamically based on current config.
@@ -106,7 +107,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
     return _ACT_FN_DEFAULT[name]


-def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
+def get_act_layer(name: Optional[LayerType] = 'relu'):
     """ Activation Layer Factory
     Fetching activation layers by name with this function allows export or torch script friendly
     functions to be returned dynamically based on current config.
@@ -125,7 +126,11 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
     return _ACT_LAYER_DEFAULT[name]


-def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
+def create_act_layer(
+        name: Optional[LayerType],
+        inplace: Optional[bool] = None,
+        **kwargs
+):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
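The factory signatures above now take Optional[LayerType] (from timm/layers/typing.py), so callers can pass a name string, an nn.Module class or callable, or None. A hedged sketch of the intended call patterns; the pass-through behaviour for non-string inputs is an assumption, only the None early-return is visible in this hunk:

import torch.nn as nn
from timm.layers import get_act_layer, create_act_layer

act_cls = get_act_layer('silu')               # look up by name
act_cls = get_act_layer(nn.GELU)              # classes are assumed to pass through unchanged
act = create_act_layer('relu', inplace=True)  # instantiate, forwarding inplace when supported
act = create_act_layer(None)                  # None is accepted and simply returns None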

timm/layers/norm.py
118 additions & 26 deletions
@@ -12,8 +12,14 @@
 import torch.nn.functional as F

 from .fast_norm import (
-    is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d,
-    fast_simple_norm, simple_norm
+    is_fast_norm,
+    fast_group_norm,
+    fast_layer_norm,
+    fast_rms_norm,
+    rms_norm2d,
+    fast_rms_norm2d,
+    fast_simple_norm,
+    simple_norm,
 )

 try:
@@ -24,15 +30,27 @@

 class GroupNorm(nn.GroupNorm):
     _fast_norm: torch.jit.Final[bool]
-
-    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
+    _fp32_norm: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            num_channels: int,
+            num_groups: int = 32,
+            eps: float = 1e-5,
+            affine: bool = True,
+            use_fp32: bool = False,
+            **kwargs,
+    ):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
-        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        super().__init__(num_groups, num_channels, eps=eps, affine=affine, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

     def forward(self, x):
         if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
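The elif branch added above is the whole mechanism: upcast to float32, normalize and apply the affine parameters there, then cast back to the caller's dtype. The same three-step pattern repeats in every class below. A standalone sketch of that branch, mirroring the diff rather than importing timm:

import torch
import torch.nn.functional as F

def group_norm_fp32(x: torch.Tensor, num_groups: int, weight=None, bias=None, eps: float = 1e-5):
    # Upcast, normalize + affine in float32, then restore the input's dtype.
    return F.group_norm(x.float(), num_groups, weight, bias, eps).to(x.dtype)

# e.g. a (2, 32, 8, 8) bfloat16 tensor keeps its dtype; only the math runs in fp32
y = group_norm_fp32(torch.randn(2, 32, 8, 8, dtype=torch.bfloat16), num_groups=32)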

@@ -42,14 +60,18 @@ class GroupNorm1(nn.GroupNorm):
     Input: tensor in shape [B, C, *]
     """
     _fast_norm: torch.jit.Final[bool]
+    _fp32_norm: torch.jit.Final[bool]

-    def __init__(self, num_channels, **kwargs):
+    def __init__(self, num_channels: int, use_fp32: bool = False, **kwargs):
         super().__init__(1, num_channels, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            return F.group_norm(x.float(), self.num_groups, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -58,14 +80,25 @@ class LayerNorm(nn.LayerNorm):
     """ LayerNorm w/ fast norm option
     """
     _fast_norm: torch.jit.Final[bool]
-
-    def __init__(self, num_channels, eps=1e-6, affine=True):
-        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+    _fp32_norm: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            num_channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            **kwargs,
+    ):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
             x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         return x
@@ -74,15 +107,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
     _fast_norm: torch.jit.Final[bool]
-
-    def __init__(self, num_channels, eps=1e-6, affine=True):
-        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+    _fp32_norm: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            num_channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            **kwargs,
+    ):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         if self._fast_norm:
             x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self._fp32_norm:
+            x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = x.permute(0, 3, 1, 2)
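LayerNorm2d handles NCHW input by permuting to NHWC, normalizing over the channel dimension (optionally in float32), and permuting back. A functional sketch of the same sequence, assuming channels-first input; not part of the commit:

import torch
import torch.nn.functional as F

def layer_norm_2d_fp32(x: torch.Tensor, weight=None, bias=None, eps: float = 1e-6):
    # x is (N, C, H, W); normalize over C in float32 and keep the original dtype.
    y = x.permute(0, 2, 3, 1)                                                 # NCHW -> NHWC
    y = F.layer_norm(y.float(), (y.shape[-1],), weight, bias, eps).to(x.dtype)
    return y.permute(0, 3, 1, 2)                                              # NHWC -> NCHW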
@@ -121,7 +165,7 @@ class LayerNormExp2d(nn.LayerNorm):
     layout. However, benefits are not always clear and can perform worse on other GPUs.
     """

-    def __init__(self, num_channels, eps=1e-6):
+    def __init__(self, num_channels: int, eps: float = 1e-6):
         super().__init__(num_channels, eps=eps)

     def forward(self, x) -> torch.Tensor:
@@ -136,13 +180,22 @@ def forward(self, x) -> torch.Tensor:
 class RmsNorm(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -153,6 +206,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -167,9 +221,11 @@ def reset_parameters(self) -> None:

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
-        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed.
         if self._fast_norm:
             x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
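For reference, textbook RMSNorm scales by the root mean square of the features with no mean subtraction; the fp32 branch above just performs that computation on an upcast copy. A hedged reference sketch of the formula, with timm's rms_norm helper in fast_norm.py assumed to be equivalent up to implementation details:

import torch

def rms_norm_fp32_ref(x: torch.Tensor, weight=None, eps: float = 1e-6):
    # y = x / sqrt(mean(x^2) + eps) * weight, computed in float32 over the last dim.
    xf = x.float()
    y = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    if weight is not None:
        y = y * weight.float()
    return y.to(x.dtype)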
@@ -182,13 +238,22 @@ class RmsNorm2d(nn.Module):
     on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
     like https://github.com/pytorch/pytorch/pull/150576 lands.
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -199,6 +264,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -216,6 +282,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
         if self._fast_norm:
             x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
         return x
@@ -224,13 +292,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SimpleNorm(nn.Module):
     """ SimpleNorm (x / std(x))
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -241,6 +318,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -256,6 +334,8 @@ def reset_parameters(self) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
             x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
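SimpleNorm's docstring reads "(x / std(x))": scale by the feature standard deviation without re-centering the input. A hedged reading of that description in float32; the actual simple_norm helper lives in fast_norm.py and may differ in detail:

import torch

def simple_norm_fp32_ref(x: torch.Tensor, weight=None, eps: float = 1e-6):
    # Divide by the feature std dev (no mean subtraction of x), computed in float32.
    xf = x.float()
    y = xf * torch.rsqrt(xf.var(dim=-1, unbiased=False, keepdim=True) + eps)
    if weight is not None:
        y = y * weight.float()
    return y.to(x.dtype)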
@@ -264,13 +344,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SimpleNorm2d(nn.Module):
     """ SimpleNorm for NCHW tensors
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm', '_fp32_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
     _fast_norm: bool
-
-    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+    _fp32_norm: bool
+
+    def __init__(
+            self,
+            channels: int,
+            eps: float = 1e-6,
+            affine: bool = True,
+            use_fp32: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         normalized_shape = channels
@@ -281,6 +370,7 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.eps = eps
         self.elementwise_affine = affine
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fp32_norm = use_fp32

         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
@@ -297,6 +387,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         if self._fast_norm:
             x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        elif self._fp32_norm:
+            x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
         else:
             x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
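Because every fp32 branch in this file ends with .to(x.dtype), the flag changes only the precision of the intermediate math, never the output dtype. A quick hedged check, assuming use_fp32 is accepted exactly as shown above and that these classes are importable from timm.layers:

import torch
from timm.layers import LayerNorm2d, RmsNorm, SimpleNorm2d  # classes patched above

x_nchw = torch.randn(2, 64, 14, 14, dtype=torch.bfloat16)
x_nlc = torch.randn(2, 196, 64, dtype=torch.bfloat16)

for layer, x in [(LayerNorm2d(64, use_fp32=True), x_nchw),
                 (RmsNorm(64, use_fp32=True), x_nlc),
                 (SimpleNorm2d(64, use_fp32=True), x_nchw)]:
    assert layer(x).dtype == x.dtype  # float32 is only used internally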
