diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 8bfb6315d..e7c60358f 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -47,6 +47,7 @@
 )
 
 OFFICIAL_MODEL_NAMES = [
+    "openai/gpt-oss-20b",
     "gpt2",
     "gpt2-medium",
     "gpt2-large",
@@ -268,6 +269,7 @@
 # Model Aliases:
 MODEL_ALIASES = {
+    "openai/gpt-oss-20b": ["gpt-oss-20b", "gpt-oss"],
     "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
     "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
     "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
@@ -794,6 +796,12 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
         architecture = "Gemma2ForCausalLM"
     elif "gemma" in official_model_name.lower():
         architecture = "GemmaForCausalLM"
+    # elif architecture in ("GptOssForCausalLM", "GPTOssForCausalLM"):
+    #     from transformer_lens.factories.architecture_adapter_factory import (
+    #         ArchitectureAdapterFactory,
+    #     )
+    #     adapter = ArchitectureAdapterFactory().from_hf_config(hf_config)
+    #     cfg_dict = adapter.make_cfg_dict(hf_config)
     else:
         huggingface_token = os.environ.get("HF_TOKEN", "")
         hf_config = AutoConfig.from_pretrained(
@@ -802,6 +810,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             **kwargs,
         )
         architecture = hf_config.architectures[0]
+
     cfg_dict: dict[str, Any]
 
     if official_model_name.startswith(
@@ -824,6 +833,39 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "final_rms": True,
             "gated_mlp": True,
         }
+
+    elif official_model_name.startswith(
+        ("gpt-oss-20b", "openai/gpt-oss-20b")
+    ):  # architecture for gpt-oss
+        cfg_dict = {
+            "d_model": 2880,
+            "d_head": 64,
+            "n_heads": 64,
+            "d_mlp": 2880,
+            "n_layers": 24,
+            "n_ctx": 131072,
+            "eps": 1e-5,
+            "d_vocab": 201088,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 64,
+            "rotary_base": 150000,
+            "num_experts": 32,
+            "experts_per_token": 4,
+            "gated_mlp": True,  # SwiGLU
+            "use_local_attn": True,  # sliding-window attention
+            "window_size": 128,  # sliding_window
+            "attn_types": [  # from HF layer_types (24 layers, alternating local/global)
+                "local", "global", "local", "global", "local", "global",
+                "local", "global", "local", "global", "local", "global",
+                "local", "global", "local", "global", "local", "global",
+                "local", "global", "local", "global", "local", "global",
+            ],
+        }
+
     elif official_model_name.startswith("codellama"):  # same architecture CodeLlama and Llama-2
         cfg_dict = {
             "d_model": 4096,
@@ -1085,6 +1129,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "NTK_by_parts_factor": 8.0,
             "NTK_original_ctx_len": 8192,
         }
+
     elif architecture == "GPTNeoForCausalLM":
         cfg_dict = {
             "d_model": hf_config.hidden_size,
@@ -1986,10 +2031,26 @@ def get_pretrained_state_dict(
         elif cfg.original_architecture == "GemmaForCausalLM":
             state_dict = convert_gemma_weights(hf_model, cfg)
         elif cfg.original_architecture == "Gemma2ForCausalLM":
             state_dict = convert_gemma_weights(hf_model, cfg)
-        else:
-            raise ValueError(
-                f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
-            )
+        # ---------- NEW: adapter fallback for new / custom HF architectures ----------
+        elif cfg.original_architecture in ("GptOssForCausalLM", "GPTOssForCausalLM"):
+            # Lazy import to avoid a circular dependency
+            from transformer_lens.factories.architecture_adapter_factory import (
+                ArchitectureAdapterFactory,
+            )
+            adapter = ArchitectureAdapterFactory().from_hf_config(hf_model.config)
+            state_dict = adapter.to_transformer_lens_state_dict(hf_model, cfg)
+        else:
+            # Fallback: try the adapter so additional future architectures can still load
+            try:
+                from transformer_lens.factories.architecture_adapter_factory import (
+                    ArchitectureAdapterFactory,
+                )
+                adapter = ArchitectureAdapterFactory().from_hf_config(hf_model.config)
+                state_dict = adapter.to_transformer_lens_state_dict(hf_model, cfg)
+            except Exception:
+                raise ValueError(
+                    f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
+                )
 
         return state_dict
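
The fallback branch above leans on `transformer_lens.factories.architecture_adapter_factory`, which is not included in this diff. The sketch below is a hypothetical reading of the interface implied by the calls `from_hf_config`, `make_cfg_dict`, and `to_transformer_lens_state_dict`; the class and method names come from the diff, but the signatures and behavior are assumptions, not the actual module.

```python
# Hypothetical sketch only: the real module
# transformer_lens/factories/architecture_adapter_factory.py is not part of this
# diff, so every signature below is inferred from the call sites in
# convert_hf_model_config and get_pretrained_state_dict.
from typing import Any, Protocol

import torch


class ArchitectureAdapter(Protocol):
    def make_cfg_dict(self, hf_config: Any) -> dict[str, Any]:
        """Translate a Hugging Face config into a HookedTransformerConfig-style dict."""
        ...

    def to_transformer_lens_state_dict(
        self, hf_model: Any, cfg: Any
    ) -> dict[str, torch.Tensor]:
        """Rename / reshape HF weights into TransformerLens parameter names."""
        ...


class ArchitectureAdapterFactory:
    """Dispatches on hf_config.architectures[0] and returns a matching adapter."""

    def from_hf_config(self, hf_config: Any) -> ArchitectureAdapter:
        # Raising for unknown architectures lets the caller's try/except fall
        # back to the existing "not currently supported" ValueError.
        raise NotImplementedError(
            f"No adapter registered for {hf_config.architectures[0]}"
        )
```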
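
For reference, a minimal end-to-end sketch of what the change is intended to enable, assuming the adapter module above ships with this PR and the gpt-oss weight conversion succeeds; the aliases and config values come from the diff, the rest is illustrative.

```python
# Illustrative usage sketch, not part of this diff. Assumes the adapter module
# referenced above exists and that the gpt-oss weight conversion works end to end.
from transformer_lens import HookedTransformer

# "gpt-oss-20b" and "gpt-oss" are the aliases registered in MODEL_ALIASES above;
# both resolve to "openai/gpt-oss-20b".
model = HookedTransformer.from_pretrained("gpt-oss-20b")

# These should match the hard-coded cfg_dict added in convert_hf_model_config.
print(model.cfg.n_layers)           # 24
print(model.cfg.num_experts)        # 32
print(model.cfg.experts_per_token)  # 4
```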