|
13 | 13 | import torch.nn as nn |
14 | 14 | import torch.nn.functional as F |
15 | 15 |
|
16 | | -from ....utils import is_fla_available |
| 16 | +from ....utils import divide_if_divisible, is_fla_available |
17 | 17 | from ...cache import GenerationCache |
18 | 18 | from ...parameter import mark_parameter_as_no_weight_decay |
19 | 19 | from ..convolution import ParameterizedConv1d |
|
29 | 29 | class GatedDeltaNet(nn.Module): |
30 | 30 | def __init__( |
31 | 31 | self, |
32 | | - hidden_size: int = 2048, |
33 | | - head_dim: int = 256, |
34 | | - v_head_dim: int = 512, |
35 | | - num_heads: int = 6, |
36 | | - num_v_heads: int = None, |
37 | | - mode: str = "chunk", |
38 | | - use_gate: bool = True, |
39 | | - allow_neg_eigval: bool = False, |
40 | | - conv_size: int = 4, |
41 | | - layer_idx: int = None, |
42 | | - norm_eps: float = 1e-5, |
43 | | - init_method: str = "normal", |
44 | | - initializer_range: float = 0.02, |
45 | | - num_layers: int = 24, |
46 | | - use_padding_free_transformer: bool = False, |
| 32 | + hidden_size: int,
| 33 | + k_head_dim: int,
| 34 | + v_head_dim: int,
| 35 | + num_k_heads: int,
| 36 | + num_v_heads: int,
| 37 | + use_gate: bool,
| 38 | + allow_neg_eigval: bool,
| 39 | + conv_size: int, |
| 40 | + layer_idx: int, |
| 41 | + norm_eps: float, |
| 42 | + init_method: str, |
| 43 | + initializer_range: float, |
| 44 | + num_layers: int, |
| 45 | + use_padding_free_transformer: bool, |
47 | 46 | ) -> GatedDeltaNet: |
48 | 47 | super().__init__() |
49 | 48 |
|
50 | | - self.mode = mode |
| 49 | + assert not use_padding_free_transformer |
| 50 | + |
| 51 | + self.mode = "chunk" |
51 | 52 | self.allow_neg_eigval = allow_neg_eigval |
52 | 53 | self.hidden_size = hidden_size |
53 | 54 |
|
54 | 55 | self.use_gate = use_gate |
55 | 56 | self.conv_size = conv_size |
56 | 57 |
|
57 | | - self.head_dim = head_dim |
58 | | - self.num_heads = num_heads |
| 58 | + self.num_k_heads = num_k_heads |
59 | | - self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads |
| 59 | + self.num_v_heads = num_v_heads
60 | 60 |
|
61 | | - self.k_head_dim = head_dim |
| 61 | + self.k_head_dim = k_head_dim |
62 | 62 | self.v_head_dim = v_head_dim |
63 | | - self.key_dim = int(self.num_heads * self.k_head_dim) |
64 | | - self.value_dim = int(self.num_v_heads * self.v_head_dim) |
| 63 | + |
| 64 | + self.key_dim = self.num_k_heads * self.k_head_dim |
| 65 | + self.value_dim = self.num_v_heads * self.v_head_dim |
65 | 66 | self.layer_idx = layer_idx |
66 | 67 |
|
67 | | - if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0: |
68 | | - raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.") |
| 68 | + divide_if_divisible(self.num_v_heads, self.num_k_heads) |
69 | 69 |
|
70 | | - assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`." |
71 | 71 |
|
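The `divide_if_divisible(self.num_v_heads, self.num_k_heads)` call above replaces the old manual check and is expected to raise when the value-head count is not a multiple of the key-head count. A minimal sketch of that contract, assuming the helper simply validates and returns the quotient (the real implementation lives in `....utils` and may take extra arguments, e.g. an error message):

```python
# Hypothetical sketch of the divide_if_divisible contract assumed by this change;
# the actual helper in ....utils may differ in signature and error text.
def divide_if_divisible(dividend: int, divisor: int) -> int:
    if dividend % divisor != 0:
        raise ValueError(f"{dividend} is not divisible by {divisor}")
    return dividend // divisor


divide_if_divisible(6, 2)  # returns 3
divide_if_divisible(6, 4)  # raises ValueError, matching the removed num_v_heads check
```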
@@ -170,9 +170,9 @@ def forward( |
170 | 170 | k = k.view(*q_size[:-1], -1, self.k_head_dim) |
171 | 171 | v = v.view(*v.size()[:-1], -1, self.v_head_dim) |
172 | 172 |
|
173 | | - if self.num_v_heads > self.num_heads: |
174 | | - q = q.repeat_interleave(repeats=self.num_v_heads // self.num_heads, dim=-2) |
175 | | - k = k.repeat_interleave(repeats=self.num_v_heads // self.num_heads, dim=-2) |
| 173 | + if self.num_v_heads > self.num_k_heads: |
| 174 | + q = q.repeat_interleave(repeats=self.num_v_heads // self.num_k_heads, dim=-2) |
| 175 | + k = k.repeat_interleave(repeats=self.num_v_heads // self.num_k_heads, dim=-2) |
176 | 176 |
|
177 | 177 | beta = b.sigmoid() |
178 | 178 | if self.allow_neg_eigval: |
|
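The renamed attributes feed the grouped-head expansion in `forward`: when `num_v_heads > num_k_heads`, each query/key head is repeated `num_v_heads // num_k_heads` times along the head axis so that q and k line up with v. A standalone sketch of the shape arithmetic (the tensor sizes below are illustrative only, not taken from the model config):

```python
import torch

# Illustrative sizes: batch=2, seq_len=5, num_k_heads=2, num_v_heads=6, k_head_dim=4
num_k_heads, num_v_heads, k_head_dim = 2, 6, 4
q = torch.randn(2, 5, num_k_heads, k_head_dim)
k = torch.randn(2, 5, num_k_heads, k_head_dim)

# Guaranteed to be an integer because __init__ checks divisibility up front
repeats = num_v_heads // num_k_heads  # 3

# Duplicate every q/k head `repeats` times along the head axis (dim=-2)
q = q.repeat_interleave(repeats=repeats, dim=-2)
k = k.repeat_interleave(repeats=repeats, dim=-2)

print(q.shape)  # torch.Size([2, 5, 6, 4]) -- now matches num_v_heads
```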