
Commit 2446a17

Language model block (#372)
1 parent 6c82d99 commit 2446a17

14 files changed, +86 -68 lines changed

Dockerfile

Lines changed: 5 additions & 0 deletions
@@ -47,3 +47,8 @@ COPY --chmod=777 ./tests tests
 COPY --chmod=777 ./tools tools
 COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models
 COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/
+
+# Set a dummy default user so we don't run in root by default.
+# The image is still compatible with any user id.
+RUN useradd user
+USER user

examples/mistral.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,6 @@ optimizer:
 model:
   base_model:
     embeddings:
-      hidden_size: 4096
      vocab_size: 32000
      dropout: 0.0
    decoder:
@@ -58,6 +57,7 @@ model:
        normalization:
          type: rms_norm
          epsilon: 1.0e-05
+    hidden_size: 4096
    tied_embedding_weight: false
  multi_stage:
    zero_stage: 2

fast_llm/engine/checkpoint/huggingface.py

Lines changed: 2 additions & 3 deletions
@@ -150,7 +150,6 @@ def _load_weights(
                 ].values()
             }
         elif (config.path / transformers.utils.WEIGHTS_NAME).is_file():
-            # TODO: Prevent unsafe by default
             paths = {config.path / transformers.utils.WEIGHTS_NAME}
         elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file():
             logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}")
@@ -170,7 +169,7 @@ def _load_weights(
                 for key in f.keys():
                     yield key, "weights", f.get_slice(key)
             elif path.suffix == ".bin":
-                # TODO: Prevent unsafe by default
-                yield from torch.load(path)
+                # TODO: Confirm that loading works with `weights_only=True`
+                yield from torch.load(path, weights_only=True)
             else:
                 raise NotImplementedError(f"Unknown file format for {path}")
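
Note on the `.bin` branch above: `weights_only=True` restricts `torch.load` to tensors and basic containers instead of arbitrary pickled objects, which closes the code-execution hole the old "Prevent unsafe by default" TODOs referred to. A minimal standalone sketch of the same idea (the file name and flat state-dict layout are assumptions for illustration, not the checkpoint handler's actual inputs):

import torch

# Hypothetical checkpoint path; a Hugging Face `pytorch_model.bin` is a pickled state dict.
path = "pytorch_model.bin"

# `weights_only=True` refuses to unpickle arbitrary Python objects, so a tampered
# checkpoint cannot run code at load time; it can only produce tensors and containers.
state_dict = torch.load(path, map_location="cpu", weights_only=True)

# Iterate name/tensor pairs, roughly what a weight converter would consume downstream.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape), tensor.dtype)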

fast_llm/engine/config_utils/parameter.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,8 @@
 import math
 import typing
 
-from fast_llm.config import Config, Field, FieldHint, config_class
+from fast_llm.config import Field, FieldHint, config_class
+from fast_llm.engine.base_model.config import ModuleConfig
 from fast_llm.engine.config_utils.initialization import Initialization, InitializationConfig
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.layers.common.peft.config import PeftConfig
@@ -36,7 +37,7 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]):
 
 
 @config_class()
-class ParameterConfig(Config):
+class ParameterConfig(ModuleConfig):
     initialization: InitializationConfig = Field(
         desc="If provided, override the default initialization method set by the parent layer.",
         hint=FieldHint.feature,

fast_llm/layers/block/block.py

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ def __init__(
         config: ConfigType,
         distributed_config: DistributedConfig,
         *,
+        # TODO: Review. Use `input_dim(s)` and `output_dim(s)` instead?
         hidden_dim: TensorDim,
         lr_scale: float | None,
         peft: PeftConfig | None,

fast_llm/layers/language_model/config.py

Lines changed: 19 additions & 14 deletions
@@ -2,7 +2,6 @@
 import typing
 
 from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
-from fast_llm.engine.base_model.config import ModuleConfig
 from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
@@ -16,6 +15,7 @@
 if typing.TYPE_CHECKING:
     from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
     from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase
+    from fast_llm.layers.language_model.language_model import LanguageModel
     from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction
 
 
@@ -41,12 +41,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig):
         desc="Configuration for the word embedding (weight).",
         hint=FieldHint.architecture,
     )
-    hidden_size: int = Field(
-        default=1024,
-        desc="Size of the model's main hidden dimension, e.g., for its input and output layers.",
-        hint=FieldHint.architecture,
-        valid=check_field(Assert.gt, 0),
-    )
     vocab_size: int = Field(
         default=49152,
         desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.",
@@ -295,24 +289,29 @@ def max_prediction_distance(self) -> int:
 
 
 @config_class()
-class LanguageModelConfig(ModuleConfig):
-    # TODO: block
+class LanguageModelConfig(BlockConfig):
     decoder: BlockSequenceConfig = Field(
         desc="Configuration for the language model decoder.",
         hint=FieldHint.architecture,
     )
-    embeddings: LanguageModelEmbeddingsConfig = Field()
-    head: LanguageModelHeadBaseConfig = Field()
-    # TODO: Allow overriding in sub-models?
-    peft: PeftConfig = Field(
-        desc="Configuration for parameter-efficient fine tuning.",
+    embeddings: LanguageModelEmbeddingsConfig = Field(
         hint=FieldHint.architecture,
+        desc="Configuration for the language model embeddings.",
+    )
+    head: LanguageModelHeadBaseConfig = Field(
+        hint=FieldHint.architecture, desc="Configuration for the language model head(s)."
     )
     tied_embedding_weight: bool = Field(
         default=False,
         desc="Tie the output weights (logits) with the vocabulary embedding.",
         hint=FieldHint.architecture,
     )
+    hidden_size: int = Field(
+        default=1024,
+        desc="Size of the model's main hidden dimension, e.g., for its input and output layers.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
     sequence_first: bool | None = Field(
         default=None,
         desc="Override the default dimension ordering",
@@ -321,3 +320,9 @@ class LanguageModelConfig(ModuleConfig):
         " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.",
         hint=FieldHint.testing,
     )
+
+    @property
+    def layer_class(self) -> "type[LanguageModel]":
+        from fast_llm.layers.language_model.language_model import LanguageModel
+
+        return LanguageModel
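
The new `layer_class` property follows the pattern visible elsewhere in this diff: a config class declares which layer it builds, and the import happens inside the property so the config module does not depend on the layer module at import time (avoiding a circular import). A generic, self-contained sketch of that pattern (toy classes, not the actual Fast-LLM `get_layer` machinery):

class ToyLayer:
    def __init__(self, config: "ToyLayerConfig", hidden_size: int):
        self.config = config
        self.hidden_size = hidden_size


class ToyLayerConfig:
    hidden_size: int = 1024

    @property
    def layer_class(self) -> type[ToyLayer]:
        # In the real code the class is imported lazily here, so the config module
        # never needs the layer module at import time.
        return ToyLayer

    def get_layer(self, hidden_size: int | None = None) -> ToyLayer:
        # A parent block can instantiate the child layer from its config alone.
        return self.layer_class(self, hidden_size or self.hidden_size)


layer = ToyLayerConfig().get_layer()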

fast_llm/layers/language_model/language_model.py

Lines changed: 26 additions & 14 deletions
@@ -1,52 +1,64 @@
 import logging
 import typing
 
-from fast_llm.config import Configurable
-from fast_llm.engine.base_model.base_model import Layer, LayerBase
+import torch
+
+from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.base_model.config import LossDef
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.layers.block.block import BlockBase
+from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.language_model.config import LanguageModelConfig
 from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
 
 logger = logging.getLogger(__name__)
 
 
-class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase):
+class LanguageModel[ConfigType: LanguageModelConfig](BlockBase[ConfigType]):
     _config: ConfigType
 
     def __init__(
         self,
         config: ConfigType,
         distributed_config: DistributedConfig,
+        *,
+        # TODO: Unused, but required by the `BlockBase` interface.
+        hidden_dim: TensorDim | None = None,
+        lr_scale: float | None,
+        peft: PeftConfig | None,
     ):
-        super().__init__(config, distributed_config)
-
-        self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size)
+        super().__init__(
+            config,
+            distributed_config,
+            hidden_dim=TensorDim("hidden", config.hidden_size),
+            lr_scale=lr_scale,
+            peft=peft,
+        )
         self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer(
             distributed_config,
             hidden_dim=self._hidden_dim,
-            lr_scale=None,
-            peft=self._config.peft,
+            lr_scale=self._lr_scale,
+            peft=self._peft,
         )
         self.decoder = self._config.decoder.get_layer(
            distributed_config,
            self._hidden_dim,
-            lr_scale=None,
-            peft=self._config.peft,
+            lr_scale=self._lr_scale,
+            peft=self._peft,
        )
        self.head = self._config.head.get_layer(
            distributed_config,
            self._config.embeddings,
            hidden_dim=self._hidden_dim,
-            lr_scale=None,
-            peft=self._config.peft,
+            lr_scale=self._lr_scale,
+            peft=self._peft,
        )
 
-    def get_layers(self) -> list["Layer"]:
+    def get_layers(self) -> list[Layer]:
         return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()
 
-    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
+    def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
         # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable?
         self.embeddings.preprocess(batch, kwargs)
         self.decoder.preprocess(batch, kwargs)
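
The net effect of this file is that `LanguageModel` now behaves like any other composite block: it receives `lr_scale` and `peft` from its parent, derives the hidden `TensorDim` from the model-level `hidden_size`, and forwards all three to its children instead of each child reading them from the global config. A toy sketch of that forwarding pattern (simplified stand-in classes, not Fast-LLM's `BlockBase`):

class ToyBlock:
    """Leaf block: keeps whatever settings its parent hands down."""

    def __init__(self, hidden_size: int, *, lr_scale: float | None, peft: object | None):
        self.hidden_size = hidden_size
        self.lr_scale = lr_scale
        self.peft = peft


class ToyLanguageModel(ToyBlock):
    """Composite block: forwards its own settings to every child."""

    def __init__(self, hidden_size: int, *, lr_scale: float | None = None, peft: object | None = None):
        super().__init__(hidden_size, lr_scale=lr_scale, peft=peft)
        # Children inherit the parent's settings rather than reading a global config.
        self.embeddings = ToyBlock(hidden_size, lr_scale=lr_scale, peft=peft)
        self.decoder = ToyBlock(hidden_size, lr_scale=lr_scale, peft=peft)
        self.head = ToyBlock(hidden_size, lr_scale=lr_scale, peft=peft)


model = ToyLanguageModel(4096)  # all children see hidden_size=4096, lr_scale=None, peft=None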

fast_llm/models/gpt/config.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,7 @@
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig
 from fast_llm.engine.schedule.config import BatchConfig
 from fast_llm.engine.training.config import TrainerConfig
+from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig
 from fast_llm.models.gpt.conversion.config import (
     AprielHybridSSMCheckpointFormat,
@@ -84,6 +85,11 @@ def micro_batch_splits(self) -> int:
 class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig):
     _abstract = False
 
+    # TODO: Allow overriding in sub-models?
+    peft: PeftConfig = Field(
+        desc="Configuration for parameter-efficient fine tuning.",
+        hint=FieldHint.architecture,
+    )
     # Debug, to get an exact match with megatron init.
     use_megatron_initialization: bool = Field(
         default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing

fast_llm/models/gpt/conversion/llama.py

Lines changed: 7 additions & 9 deletions
@@ -449,19 +449,13 @@ def get_converters(
 class LlamaEmbeddingsConverter:
     @classmethod
     def import_config(cls, config: dict) -> dict:
-        return {
-            "vocab_size": config["vocab_size"],
-            "hidden_size": config["hidden_size"],
-        }
+        return {"vocab_size": config["vocab_size"]}
 
     @classmethod
     def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict:
         Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig)
         assert not config.position_embeddings.enabled
-        return {
-            "vocab_size": config.vocab_size,
-            "hidden_size": config.hidden_size,
-        }
+        return {"vocab_size": config.vocab_size}
 
     @classmethod
     def get_converters(
@@ -516,6 +510,7 @@ def import_config(cls, config: dict) -> dict:
             "embeddings": cls.embeddings_converter_class.import_config(config),
             "decoder": cls.decoder_converter_class.import_config(config),
             "head": cls.head_converter_class.import_config(config),
+            "hidden_size": config["hidden_size"],
             "tied_embedding_weight": config["tie_word_embeddings"],
         }
 
@@ -526,7 +521,10 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict:
             cls.embeddings_converter_class.export_config(config.embeddings),
             cls.decoder_converter_class.export_config(config.decoder),
             cls.head_converter_class.export_config(config.head),
-            {"tie_word_embeddings": config.tied_embedding_weight},
+            {
+                "tie_word_embeddings": config.tied_embedding_weight,
+                "hidden_size": config.hidden_size,
+            },
         )
 
     @classmethod
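
Since `hidden_size` now lives on `LanguageModelConfig` rather than on the embeddings config, the Llama converter imports and exports it at the model level, and the embeddings converter only handles `vocab_size`. A rough sketch of the resulting export shape, using plain dict merging as a stand-in for the converter's actual combination step (the values are illustrative, taken from the Mistral example above):

# Assumed sub-dicts produced by the per-component converters.
embeddings_part = {"vocab_size": 32000}
decoder_part = {"num_hidden_layers": 32}
head_part = {"rms_norm_eps": 1e-5}
model_part = {"tie_word_embeddings": False, "hidden_size": 4096}

# The exported Hugging Face config combines all parts; `hidden_size` now comes from
# the model-level part instead of the embeddings part.
hf_config = {**embeddings_part, **decoder_part, **head_part, **model_part}
print(hf_config["hidden_size"])  # 4096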

fast_llm/models/gpt/model.py

Lines changed: 3 additions & 5 deletions
@@ -30,16 +30,14 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], Ba
 
     def __init__(
         self,
-        config: GPTBaseModelConfig,
+        config: ConfigType,
         distributed_config: DistributedConfig,
     ):
-        super().__init__(config, distributed_config)
+        super().__init__(config, distributed_config, lr_scale=config.lr_scale, peft=config.peft)
         if self._config.use_megatron_initialization:
             for param in self.parameters():
                 Assert.custom(isinstance, param, ParameterMeta)
-                param.init_parameter = get_init_megatron(
-                    param, self._config.decoder.block, config.embeddings.hidden_size
-                )  # Noqa
+                param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size)  # Noqa
 
     def preprocess_meta(
         self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType
