Skip to content

Commit e4c99b0

Browse files
zlsh80826 and phu0ngng authored
[JAX] Use default factory for not sharing mutable default values (NVIDIA#1364)
* Bug Fix: Use default factory for not sharing mutable default values

---------

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
1 parent 3102fdd commit e4c99b0

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

transformer_engine/jax/praxis/module.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
Praxis Modules
66
"""
7+
from dataclasses import field
78
from functools import partial
89
from typing import Callable, Iterable, Sequence, Tuple, Union
910

@@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer):
7475
zero_centered_gamma: bool = False
7576
scale_init: WeightInit = None
7677
scale_axes: Tuple[str, ...] = ()
77-
bias_init: WeightInit = WeightInit.Constant(0.0)
78+
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
79+
default_factory=partial(WeightInit.Constant, scale=0.0)
80+
)
7881
bias_axes: Tuple[str, ...] = ()
7982
transpose_batch_sequence: bool = False
8083

@@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer):
129132
out_features: int = 512
130133
kernel_axes: Tuple[str, ...] = ()
131134
use_bias: bool = True
132-
bias_init: WeightInit = WeightInit.Constant(0.0)
135+
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
136+
default_factory=partial(WeightInit.Constant, scale=0.0)
137+
)
133138
bias_axes: Tuple[str, ...] = ()
134139
enable_low_rank_adaptation: bool = False
135140
low_rank_adaptation_dim: int = 32
@@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
174179
zero_centered_gamma: bool = False
175180
scale_init: WeightInit = None
176181
scale_axes: Tuple[str, ...] = ()
177-
ln_bias_init: WeightInit = WeightInit.Constant(1.0)
182+
ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
183+
default_factory=partial(WeightInit.Constant, scale=1.0)
184+
)
178185
ln_bias_axes: Tuple[str, ...] = ()
179186
kernel_axes: Tuple[str, ...] = ()
180187
use_bias: bool = False
181-
bias_init: WeightInit = WeightInit.Constant(0.0)
188+
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
189+
default_factory=partial(WeightInit.Constant, scale=0.0)
190+
)
182191
bias_axes: Tuple[str, ...] = ()
183192
enable_low_rank_adaptation: bool = False
184193
low_rank_adaptation_dim: int = 32
@@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
237246
zero_centered_gamma: bool = False
238247
scale_init: WeightInit = None
239248
scale_axes: Tuple[str, ...] = ()
240-
ln_bias_init: WeightInit = WeightInit.Constant(1.0)
249+
ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
250+
default_factory=partial(WeightInit.Constant, scale=1.0)
251+
)
241252
ln_bias_axes: Tuple[str, ...] = ()
242253
kernel_axes_1: Tuple[str, ...] = ()
243254
kernel_axes_2: Tuple[str, ...] = ()
244255
use_bias: bool = False
245-
bias_init: WeightInit = WeightInit.Constant(0.0)
256+
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
257+
default_factory=partial(WeightInit.Constant, scale=0.0)
258+
)
246259
bias_axes_1: Tuple[str, ...] = ()
247260
bias_axes_2: Tuple[str, ...] = ()
248261
enable_low_rank_adaptation: bool = False

transformer_engine/jax/praxis/transformer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
Praxis Modules related Transformer
66
"""
7+
from dataclasses import field
78
from functools import partial
89
from typing import Optional, Sequence, Tuple
910
import warnings
@@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
138139
zero_centered_gamma: bool = False
139140
return_layernorm_output: bool = False
140141
use_bias: bool = False
141-
bias_init: WeightInit = WeightInit.Constant(0.0)
142+
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
143+
default_factory=partial(WeightInit.Constant, scale=0.0)
144+
)
142145
attn_mask_type: str = "causal"
143146
attn_bias_type: Optional[str] = None
144147
enable_rotary_pos_emb: bool = False
@@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
275278
dropout_rng_name: str = "dropout"
276279
mlp_activations: Sequence[str] = ("relu",)
277280
use_bias: bool = False
278-
bias_init: WeightInit = WeightInit.Constant(0.0)
281+
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
282+
default_factory=partial(WeightInit.Constant, scale=0.0)
283+
)
279284
apply_residual_connection_post_layernorm: bool = False
280285
output_layernorm: bool = False
281286
float32_attention_logits: bool = False

0 commit comments

Comments (0)