File tree (expand/collapse) — 2 files changed: +16 −1 lines changed
lines changed Original file line number Diff line number Diff line change @@ -41,10 +41,13 @@ class MoEMergeConfig(BaseModel):
4141
4242 base_model : ModelReference
4343 experts : List [Expert ]
44- gate_mode : str = "hidden" # possible values: "hidden", "cheap_embed", "random"
44+ gate_mode : str = (
45+ "hidden" # possible values: "hidden", "cheap_embed", "random", "uniform_random"
46+ )
4547 # "hidden" uses hidden state vectors for the given prompts for each layer
4648 # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
4749 # "random" is random
50+ # "uniform_random" matches default initialization for torch.nn.Linear
4851 dtype : Optional [str ] = None
4952 experts_per_token : int = 2
5053 shared_experts : Optional [List [Expert ]] = None
Original file line number Diff line number Diff line change 1414# along with this program. If not, see http://www.gnu.org/licenses/.
1515
1616import logging
17+ import math
1718from typing import Dict , List , Union
1819
1920import torch
@@ -99,6 +100,17 @@ def get_gate_params(
99100 return torch .randn (
100101 (model_cfg .num_hidden_layers , len (experts ), model_cfg .hidden_size )
101102 )
103+ elif mode == "uniform_random" :
104+ in_features = model_cfg .hidden_size
105+ scale = math .sqrt (1.0 / in_features )
106+ return (
107+ torch .rand (
108+ (model_cfg .num_hidden_layers , len (experts ), model_cfg .hidden_size )
109+ )
110+ * 2
111+ * scale
112+ - scale
113+ )
102114 elif mode == "cheap_embed" :
103115 embed = model_ref .lazy_loader (lazy_unpickle = lazy_unpickle ).get_tensor (
104116 "model.embed_tokens.weight"
You can’t perform that action at this time.
0 commit comments