update1 #1038

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 72 additions & 6 deletions transformer_lens/loading_from_pretrained.py
@@ -47,6 +47,7 @@
)

OFFICIAL_MODEL_NAMES = [
"openai/gpt-oss-20b",
"gpt2",
"gpt2-medium",
"gpt2-large",
@@ -268,6 +269,7 @@

# Model Aliases:
MODEL_ALIASES = {
"openai/gpt-oss-20b": ["gpt-oss-20b", "gpt-oss"],
"NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
"NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
"NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
@@ -794,6 +796,12 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
architecture = "Gemma2ForCausalLM"
elif "gemma" in official_model_name.lower():
architecture = "GemmaForCausalLM"
# elif architecture in ("GptOssForCausalLM", "GPTOssForCausalLM"):
# from transformer_lens.factories.architecture_adapter_factory import (
# ArchitectureAdapterFactory,
# )
# adapter = ArchitectureAdapterFactory().from_hf_config(hf_config)
# cfg_dict = adapter.make_cfg_dict(hf_config)
else:
huggingface_token = os.environ.get("HF_TOKEN", "")
hf_config = AutoConfig.from_pretrained(
@@ -802,6 +810,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
**kwargs,
)
architecture = hf_config.architectures[0]
cfg_dict: dict[str, Any]
if official_model_name.startswith(
@@ -824,6 +833,41 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"final_rms": True,
"gated_mlp": True,
}

elif official_model_name.startswith(
("gpt-oss-20b", "openai/gpt-oss-20b")
): # hard-coded config for gpt-oss-20b
cfg_dict = {
"d_model": 2880,
"d_head": 64,
"n_heads": 64,
"d_mlp": 2880,
"n_layers": 24,
"n_ctx": 131072,
"eps": 1e-5,
"d_vocab": 201088,
"act_fn": "silu",
"n_key_value_heads": 8,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_adjacent_pairs": False,
"rotary_dim": 64,
"rotary_base": 150000,
"num_experts": 32,
"experts_per_token": 4,
"gated_mlp": True, # SWiGLU
"use_local_attn": True, # sliding window attention
"window_size": 128, # sliding_window
"attn_types": [ # 来自 layer_types(24 层交替)
"local","global","local","global","local","global",
"local","global","local","global","local","global",
"local","global","local","global","local","global",
"local","global","local","global","local","global",
],
}
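For reference, this is how the new alias and the hard-coded config above are meant to be exercised. A minimal usage sketch, assuming the adapter-based weight conversion added later in this PR actually completes for this architecture (that path is only partially shown in this diff):

    from transformer_lens import HookedTransformer

    # "gpt-oss-20b" resolves through MODEL_ALIASES to "openai/gpt-oss-20b",
    # which selects the cfg_dict above: a 24-layer MoE model (32 experts,
    # 4 active per token) with alternating sliding-window/full attention.
    model = HookedTransformer.from_pretrained("gpt-oss-20b")
    logits = model("Hello, world!", return_type="logits")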

elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2
cfg_dict = {
"d_model": 4096,
@@ -838,9 +882,9 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_dim": 4096 // 32,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 1000000,
"n_key_value_heads": 8,
"num_experts": 32,
"experts_per_token": 4,
}
if "python" in official_model_name.lower():
# The vocab size of python version of CodeLlama-7b is 32000
@@ -1085,6 +1129,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"NTK_by_parts_factor": 8.0,
"NTK_original_ctx_len": 8192,
}

elif architecture == "GPTNeoForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
@@ -1986,10 +2031,31 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
)
# ---------- NEW: adapter fallback for new / custom HF architectures ----------
elif cfg.original_architecture in ("GptOssForCausalLM", "GPTOssForCausalLM"):
# Deferred import to avoid a circular dependency
from transformer_lens.factories.architecture_adapter_factory import (
ArchitectureAdapterFactory,
)
adapter = ArchitectureAdapterFactory().from_hf_config(hf_model.config)
state_dict = adapter.to_transformer_lens_state_dict(hf_model, cfg)
else:
# Fallback: try the adapter so newer/unknown architectures can still load
try:
from transformer_lens.factories.architecture_adapter_factory import (
ArchitectureAdapterFactory,
)
adapter = ArchitectureAdapterFactory().from_hf_config(hf_model.config)
state_dict = adapter.to_transformer_lens_state_dict(hf_model, cfg)
except Exception as err:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
) from err

return state_dict
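ArchitectureAdapterFactory is imported at several call sites above but never defined in this diff. Below is a minimal sketch of the interface those call sites assume; only the three names the PR actually uses (from_hf_config, make_cfg_dict, to_transformer_lens_state_dict) come from the diff, and everything else is illustrative:

    from typing import Any, Protocol

    import torch


    class ArchitectureAdapter(Protocol):
        def make_cfg_dict(self, hf_config: Any) -> dict[str, Any]:
            """Build a TransformerLens cfg_dict from a HuggingFace config."""
            ...

        def to_transformer_lens_state_dict(self, hf_model: Any, cfg: Any) -> dict[str, torch.Tensor]:
            """Map HuggingFace parameter names and shapes onto TransformerLens ones."""
            ...


    class ArchitectureAdapterFactory:
        def from_hf_config(self, hf_config: Any) -> ArchitectureAdapter:
            # Dispatch on hf_config.architectures[0]; raising here is what lets
            # the try/except fallback above surface the original ValueError.
            raise NotImplementedError(hf_config.architectures[0])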
