Skip to content

Commit 8b155a2

Browse files
committed
init
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
1 parent 0260b71 commit 8b155a2

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ def get_sequence_mixer(
140140
conv_size=block.conv_size,
141141
layer_idx=layer_idx,
142142
norm_eps=config.layer_norm_epsilon,
143+
init_method=config.init_method,
144+
initializer_range=config.initializer_range,
145+
num_layers=config.num_layers,
143146
use_padding_free_transformer=use_padding_free_transformer,
144147
)
145148
else:

lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
self.qkv_proj = ParameterizedLinear(hidden_size, 2 * self.key_dim + self.value_dim, bias=False, std=std)
128128

129129
self.ab_proj = ParameterizedLinear(
130-
hidden_size, 2 * self.num_v_heads + (self.value_dim if use_gate else 0), bias=False
130+
hidden_size, 2 * self.num_v_heads + (self.value_dim if use_gate else 0), bias=False, std=std
131131
)
132132

133133
A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16)
@@ -154,12 +154,14 @@ def __init__(
154154
padding=conv_size - 1,
155155
groups=2 * self.key_dim + self.value_dim,
156156
bias=False,
157-
std=None, # TODO
157+
std=std, # TODO
158158
)
159159
self.activation_string = "silu"
160160

161+
std = initializer_range / math.sqrt(2 * num_layers)
162+
161163
self.o_norm = get_normalization_function("rmsnorm", self.v_head_dim, eps=norm_eps)
162-
self.o_proj = ParameterizedLinear(self.value_dim, hidden_size, bias=False)
164+
self.o_proj = ParameterizedLinear(self.value_dim, hidden_size, bias=False, std=std)
163165

164166
def forward(
165167
self,

0 commit comments

Comments (0)