22import typing
33
44from fast_llm .config import Field , FieldHint , check_field , config_class , skip_valid_if_none
5- from fast_llm .engine .base_model .config import ModuleConfig
65from fast_llm .engine .config_utils .parameter import OptionalParameterConfig , ParameterConfig , combine_lr_scales
76from fast_llm .engine .config_utils .tensor_dim import TensorDim
87from fast_llm .engine .distributed .config import DistributedConfig
1615if typing .TYPE_CHECKING :
1716 from fast_llm .layers .language_model .embedding import LanguageModelEmbedding
1817 from fast_llm .layers .language_model .head import LanguageModelHead , LanguageModelHeadBase
18+ from fast_llm .layers .language_model .language_model import LanguageModel
1919 from fast_llm .layers .language_model .multi_token_prediction import MultiTokenPrediction
2020
2121
@@ -41,12 +41,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig):
4141 desc = "Configuration for the word embedding (weight)." ,
4242 hint = FieldHint .architecture ,
4343 )
44- hidden_size : int = Field (
45- default = 1024 ,
46- desc = "Size of the model's main hidden dimension, e.g., for its input and output layers." ,
47- hint = FieldHint .architecture ,
48- valid = check_field (Assert .gt , 0 ),
49- )
5044 vocab_size : int = Field (
5145 default = 49152 ,
5246 desc = "Size of the vocabulary, i.e., number of vocabulary embeddings and logits." ,
@@ -295,24 +289,29 @@ def max_prediction_distance(self) -> int:
295289
296290
297291@config_class ()
298- class LanguageModelConfig (ModuleConfig ):
299- # TODO: block
292+ class LanguageModelConfig (BlockConfig ):
    # Stack of decoder blocks forming the body of the transformer.
    decoder: BlockSequenceConfig = Field(
        desc="Configuration for the language model decoder.",
        hint=FieldHint.architecture,
    )
    # Input token-embedding layer (see `LanguageModelEmbeddingsConfig`).
    embeddings: LanguageModelEmbeddingsConfig = Field(
        hint=FieldHint.architecture,
        desc="Configuration for the language model embeddings.",
    )
    # Output head(s) mapping hidden states to logits/predictions.
    head: LanguageModelHeadBaseConfig = Field(
        hint=FieldHint.architecture, desc="Configuration for the language model head(s)."
    )
    # When True, the logits projection shares its weight with the vocabulary embedding.
    tied_embedding_weight: bool = Field(
        default=False,
        desc="Tie the output weights (logits) with the vocabulary embedding.",
        hint=FieldHint.architecture,
    )
    # Main hidden dimension of the model; validated to be strictly positive.
    hidden_size: int = Field(
        default=1024,
        desc="Size of the model's main hidden dimension, e.g., for its input and output layers.",
        hint=FieldHint.architecture,
        valid=check_field(Assert.gt, 0),
    )
316315 sequence_first : bool | None = Field (
317316 default = None ,
318317 desc = "Override the default dimension ordering" ,
@@ -321,3 +320,9 @@ class LanguageModelConfig(ModuleConfig):
321320 " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error." ,
322321 hint = FieldHint .testing ,
323322 )
323+
    @property
    def layer_class(self) -> "type[LanguageModel]":
        """Concrete layer class this config instantiates.

        Returns the `LanguageModel` layer type; the import is deferred to call
        time to avoid a circular import between the config and layer modules
        (the name is only available under `typing.TYPE_CHECKING` at module scope).
        """
        from fast_llm.layers.language_model.language_model import LanguageModel

        return LanguageModel
0 commit comments