
Commit ca96e86

Add uniform_random gate mode to mergekit-moe (#303)
To better match initialization of `nn.Linear`.
1 parent 09c63e6

2 files changed: +16 −1 lines changed


mergekit/moe/config.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -41,10 +41,13 @@ class MoEMergeConfig(BaseModel):
 
     base_model: ModelReference
     experts: List[Expert]
-    gate_mode: str = "hidden"  # possible values: "hidden", "cheap_embed", "random"
+    gate_mode: str = (
+        "hidden"  # possible values: "hidden", "cheap_embed", "random", "uniform_random"
+    )
     # "hidden" uses hidden state vectors for the given prompts for each layer
     # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
     # "random" is random
+    # "uniform_random" matches default initialization for torch.nn.Linear
     dtype: Optional[str] = None
     experts_per_token: int = 2
     shared_experts: Optional[List[Expert]] = None
```
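For context, `gate_mode` is a plain field on the pydantic `MoEMergeConfig` model above, so the new mode is selected from a mergekit-moe merge config. A minimal sketch, assuming the YAML layout mergekit-moe documents for this model; the model names are placeholders, and the `positive_prompts` entries are shown for completeness even though the random gate modes do not read them:

```yaml
base_model: mistralai/Mistral-7B-v0.1      # placeholder base model
gate_mode: uniform_random                  # the mode added by this commit
dtype: bfloat16
experts_per_token: 2
experts:
  - source_model: teknium/OpenHermes-2.5-Mistral-7B   # placeholder expert
    positive_prompts: ["chat assistant"]              # unused by random modes
  - source_model: WizardLM/WizardMath-7B-V1.1         # placeholder expert
    positive_prompts: ["math problems"]               # unused by random modes
```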

mergekit/moe/router.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
 import logging
+import math
 from typing import Dict, List, Union
 
 import torch
@@ -99,6 +100,17 @@ def get_gate_params(
         return torch.randn(
             (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
         )
+    elif mode == "uniform_random":
+        in_features = model_cfg.hidden_size
+        scale = math.sqrt(1.0 / in_features)
+        return (
+            torch.rand(
+                (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
+            )
+            * 2
+            * scale
+            - scale
+        )
     elif mode == "cheap_embed":
         embed = model_ref.lazy_loader(lazy_unpickle=lazy_unpickle).get_tensor(
             "model.embed_tokens.weight"
```