diff --git a/litgpt/config.py b/litgpt/config.py
index 97549a114d..bd3b7776ec 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -2112,6 +2112,56 @@ def norm_class(self) -> Type:
         intermediate_size=28672,
     )
 )
+configs.append(
+    # https://huggingface.co/mistralai/Mistral-Small-24B-Base-2501/blob/main/config.json
+    dict(
+        name="Mistral-Small-24B-Base-2501",
+        hf_config=dict(org="mistralai", name="Mistral-Small-24B-Base-2501"),
+        padded_vocab_size=131072,
+        # NOTE(review): upstream `max_position_embeddings` is believed to be 32768 - confirm 131072 is intended.
+        block_size=131072,
+        n_layer=40,
+        n_head=32,
+        n_embd=5120,
+        # Upstream `head_dim` is 128, which differs from n_embd / n_head (5120 / 32 = 160), so set it explicitly.
+        head_size=128,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        norm_eps=1e-05,
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=32768,
+        # Upstream `rope_theta` (1e8).
+        rope_base=100000000,
+    )
+)
+configs.append(
+    # https://huggingface.co/mistralai/Mistral-Small-24B-Instruct-2501/blob/main/config.json
+    dict(
+        name="Mistral-Small-24B-Instruct-2501",
+        hf_config=dict(org="mistralai", name="Mistral-Small-24B-Instruct-2501"),
+        padded_vocab_size=131072,
+        # NOTE(review): upstream `max_position_embeddings` is believed to be 32768 - confirm 131072 is intended.
+        block_size=131072,
+        n_layer=40,
+        n_head=32,
+        n_embd=5120,
+        # Upstream `head_dim` is 128, which differs from n_embd / n_head (5120 / 32 = 160), so set it explicitly.
+        head_size=128,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        norm_eps=1e-05,
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=32768,
+        # Upstream `rope_theta` (1e8).
+        rope_base=100000000,
+    )
+)
 
 
 ############