Commit 6790c21

fla
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
1 parent 53ec521 commit 6790c21

3 files changed (+32 lines, -34 lines)

lm_engine/hf_models/config/sequence_mixer.py (3 additions, 4 deletions)

@@ -116,14 +116,13 @@ def model_post_init(self, __context: Any) -> None:
 
 class _GatedDeltaNetArgs(BaseArgs):
     sequence_mixer_type: str = "gated_deltanet"
-    head_dim: int = 256
+    k_head_dim: int = 256
     v_head_dim: int = 512
-    num_heads: int = 6
+    num_k_heads: int = 6
     num_v_heads: int = 6
-    mode: str = "chunk"
     use_gate: bool = True
     allow_neg_eigval: bool = False
-    conv_size: int = 4
+    kernel_size: int = 4
 
     def model_post_init(self, __context: Any) -> None:
         assert self.sequence_mixer_type == "gated_deltanet"
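In config terms, gated-deltanet blocks are now declared with k_head_dim, num_k_heads, and kernel_size instead of head_dim, num_heads, and conv_size, and `mode` disappears from the config entirely (the module hardcodes chunked mode; see the last file below). A minimal sketch of such a block entry; the field names come from the diff above, while the dict layout and values are illustrative assumptions, not taken from the repo:

    # Hypothetical config entry for a gated_deltanet sequence mixer after this
    # commit; field names follow the diff, values are only illustrative.
    gated_deltanet_block = {
        "sequence_mixer_type": "gated_deltanet",
        "k_head_dim": 256,    # previously head_dim
        "v_head_dim": 512,
        "num_k_heads": 6,     # previously num_heads
        "num_v_heads": 6,
        "use_gate": True,
        "allow_neg_eigval": False,
        "kernel_size": 4,     # previously conv_size
    }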

lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py (2 additions, 3 deletions)

@@ -134,11 +134,10 @@ def get_sequence_mixer(
     elif sequence_mixer_type == "gated_deltanet":
         return GatedDeltaNet(
             hidden_size=config.hidden_size,
-            head_dim=block.head_dim,
+            k_head_dim=block.k_head_dim,
             v_head_dim=block.v_head_dim,
-            num_heads=block.num_heads,
+            num_k_heads=block.num_k_heads,
             num_v_heads=block.num_v_heads,
-            mode=block.mode,
             use_gate=block.use_gate,
             allow_neg_eigval=block.allow_neg_eigval,
             conv_size=block.conv_size,
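For call sites outside this file, the commit amounts to two keyword renames and one removed keyword (the constructor still takes conv_size; only the config field was renamed to kernel_size). A hedged sketch of that migration as a standalone helper; the rename map is read off the diff, but the helper itself is hypothetical and not part of the repo:

    # Hypothetical migration helper for GatedDeltaNet call sites (not in the repo).
    _RENAMED = {"head_dim": "k_head_dim", "num_heads": "num_k_heads"}
    _REMOVED = {"mode"}  # GatedDeltaNet now hardcodes mode="chunk"

    def migrate_gated_deltanet_kwargs(kwargs: dict) -> dict:
        out = {}
        for key, value in kwargs.items():
            if key in _REMOVED:
                continue  # drop keywords the constructor no longer accepts
            out[_RENAMED.get(key, key)] = value
        return out

    print(migrate_gated_deltanet_kwargs({"head_dim": 256, "num_heads": 6, "mode": "chunk"}))
    # {'k_head_dim': 256, 'num_k_heads': 6}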

lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py (27 additions, 27 deletions)

@@ -13,7 +13,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ....utils import is_fla_available
+from ....utils import divide_if_divisible, is_fla_available
 from ...cache import GenerationCache
 from ...parameter import mark_parameter_as_no_weight_decay
 from ..convolution import ParameterizedConv1d
@@ -29,43 +29,43 @@
 class GatedDeltaNet(nn.Module):
     def __init__(
         self,
-        hidden_size: int = 2048,
-        head_dim: int = 256,
-        v_head_dim: int = 512,
-        num_heads: int = 6,
-        num_v_heads: int = None,
-        mode: str = "chunk",
-        use_gate: bool = True,
-        allow_neg_eigval: bool = False,
-        conv_size: int = 4,
-        layer_idx: int = None,
-        norm_eps: float = 1e-5,
-        init_method: str = "normal",
-        initializer_range: float = 0.02,
-        num_layers: int = 24,
-        use_padding_free_transformer: bool = False,
+        hidden_size,
+        k_head_dim,
+        v_head_dim,
+        num_k_heads,
+        num_v_heads,
+        use_gate: bool,
+        allow_neg_eigval,
+        conv_size: int,
+        layer_idx: int,
+        norm_eps: float,
+        init_method: str,
+        initializer_range: float,
+        num_layers: int,
+        use_padding_free_transformer: bool,
     ) -> GatedDeltaNet:
         super().__init__()
 
-        self.mode = mode
+        assert not use_padding_free_transformer
+
+        self.mode = "chunk"
         self.allow_neg_eigval = allow_neg_eigval
         self.hidden_size = hidden_size
 
         self.use_gate = use_gate
         self.conv_size = conv_size
 
-        self.head_dim = head_dim
-        self.num_heads = num_heads
+        self.num_k_heads = num_k_heads
         self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads
 
-        self.k_head_dim = head_dim
+        self.k_head_dim = k_head_dim
         self.v_head_dim = v_head_dim
-        self.key_dim = int(self.num_heads * self.k_head_dim)
-        self.value_dim = int(self.num_v_heads * self.v_head_dim)
+
+        self.key_dim = self.num_k_heads * self.k_head_dim
+        self.value_dim = self.num_v_heads * self.v_head_dim
         self.layer_idx = layer_idx
 
-        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
-            raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")
+        divide_if_divisible(self.num_v_heads, self.num_k_heads)
 
         assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
 
@@ -170,9 +170,9 @@ def forward(
         k = k.view(*q_size[:-1], -1, self.k_head_dim)
         v = v.view(*v.size()[:-1], -1, self.v_head_dim)
 
-        if self.num_v_heads > self.num_heads:
-            q = q.repeat_interleave(repeats=self.num_v_heads // self.num_heads, dim=-2)
-            k = k.repeat_interleave(repeats=self.num_v_heads // self.num_heads, dim=-2)
+        if self.num_v_heads > self.num_k_heads:
+            q = q.repeat_interleave(repeats=self.num_v_heads // self.num_k_heads, dim=-2)
+            k = k.repeat_interleave(repeats=self.num_v_heads // self.num_k_heads, dim=-2)
 
         beta = b.sigmoid()
         if self.allow_neg_eigval:
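Two of the changes above are worth unpacking. First, the hand-rolled divisibility check is replaced by the repo's divide_if_divisible helper; its implementation is not shown in this diff, so the following is only a sketch consistent with how it is called (assert divisibility, return the quotient):

    def divide_if_divisible(numerator: int, denominator: int) -> int:
        # Sketch only: the real helper in lm_engine.utils may differ, e.g. in
        # its error message or by taking an extra message argument.
        assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
        return numerator // denominator

Second, the forward-pass change broadcasts the query/key heads up to the value-head count, GQA-style, whenever num_v_heads exceeds num_k_heads. A standalone demo of that repeat_interleave step with illustrative shapes:

    import torch

    num_k_heads, num_v_heads, k_head_dim = 2, 6, 4
    q = torch.randn(1, 10, num_k_heads, k_head_dim)  # (batch, seq_len, heads, head_dim)
    if num_v_heads > num_k_heads:
        # repeat each of the 2 k-heads 3 times along the head axis
        q = q.repeat_interleave(repeats=num_v_heads // num_k_heads, dim=-2)
    print(q.shape)  # torch.Size([1, 10, 6, 4])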
