|
4 | 4 | """ |
5 | 5 | Praxis Modules |
6 | 6 | """ |
| 7 | +from dataclasses import field |
7 | 8 | from functools import partial |
8 | 9 | from typing import Callable, Iterable, Sequence, Tuple, Union |
9 | 10 |
|
@@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer): |
74 | 75 | zero_centered_gamma: bool = False |
75 | 76 | scale_init: WeightInit = None |
76 | 77 | scale_axes: Tuple[str, ...] = () |
77 | | - bias_init: WeightInit = WeightInit.Constant(0.0) |
| 78 | + bias_init: WeightInit = field( # pylint: disable=invalid-field-call |
| 79 | + default_factory=partial(WeightInit.Constant, scale=0.0) |
| 80 | + ) |
78 | 81 | bias_axes: Tuple[str, ...] = () |
79 | 82 | transpose_batch_sequence: bool = False |
80 | 83 |
|
@@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer): |
129 | 132 | out_features: int = 512 |
130 | 133 | kernel_axes: Tuple[str, ...] = () |
131 | 134 | use_bias: bool = True |
132 | | - bias_init: WeightInit = WeightInit.Constant(0.0) |
| 135 | + bias_init: WeightInit = field( # pylint: disable=invalid-field-call |
| 136 | + default_factory=partial(WeightInit.Constant, scale=0.0) |
| 137 | + ) |
133 | 138 | bias_axes: Tuple[str, ...] = () |
134 | 139 | enable_low_rank_adaptation: bool = False |
135 | 140 | low_rank_adaptation_dim: int = 32 |
@@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): |
174 | 179 | zero_centered_gamma: bool = False |
175 | 180 | scale_init: WeightInit = None |
176 | 181 | scale_axes: Tuple[str, ...] = () |
177 | | - ln_bias_init: WeightInit = WeightInit.Constant(1.0) |
| 182 | + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call |
| 183 | + default_factory=partial(WeightInit.Constant, scale=1.0) |
| 184 | + ) |
178 | 185 | ln_bias_axes: Tuple[str, ...] = () |
179 | 186 | kernel_axes: Tuple[str, ...] = () |
180 | 187 | use_bias: bool = False |
181 | | - bias_init: WeightInit = WeightInit.Constant(0.0) |
| 188 | + bias_init: WeightInit = field( # pylint: disable=invalid-field-call |
| 189 | + default_factory=partial(WeightInit.Constant, scale=0.0) |
| 190 | + ) |
182 | 191 | bias_axes: Tuple[str, ...] = () |
183 | 192 | enable_low_rank_adaptation: bool = False |
184 | 193 | low_rank_adaptation_dim: int = 32 |
@@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer): |
237 | 246 | zero_centered_gamma: bool = False |
238 | 247 | scale_init: WeightInit = None |
239 | 248 | scale_axes: Tuple[str, ...] = () |
240 | | - ln_bias_init: WeightInit = WeightInit.Constant(1.0) |
| 249 | + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call |
| 250 | + default_factory=partial(WeightInit.Constant, scale=1.0) |
| 251 | + ) |
241 | 252 | ln_bias_axes: Tuple[str, ...] = () |
242 | 253 | kernel_axes_1: Tuple[str, ...] = () |
243 | 254 | kernel_axes_2: Tuple[str, ...] = () |
244 | 255 | use_bias: bool = False |
245 | | - bias_init: WeightInit = WeightInit.Constant(0.0) |
| 256 | + bias_init: WeightInit = field( # pylint: disable=invalid-field-call |
| 257 | + default_factory=partial(WeightInit.Constant, scale=0.0) |
| 258 | + ) |
246 | 259 | bias_axes_1: Tuple[str, ...] = () |
247 | 260 | bias_axes_2: Tuple[str, ...] = () |
248 | 261 | enable_low_rank_adaptation: bool = False |
|
0 commit comments