
Commit bda052f

Base model interface review (#370)
1 parent cc5ca89 commit bda052f


58 files changed: +1403 −1264 lines

examples/mistral.yaml

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@ optimizer:
   beta_2: 0.95
 model:
   base_model:
-    embeddings_layer:
+    embeddings:
       hidden_size: 4096
       vocab_size: 32000
       dropout: 0.0
@@ -54,11 +54,11 @@ model:
         epsilon: 1.0e-05
       dropout: 0.0
       num_blocks: 32
-      output_layer:
-        tied_weight: false
+      head:
         normalization:
           type: rms_norm
           epsilon: 1.0e-05
+        tied_embedding_weight: false
 multi_stage:
   zero_stage: 2
 distributed:
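
The example config reflects the renamed `base_model` sections: `embeddings_layer` becomes `embeddings`, `output_layer` becomes `head`, and the tying flag becomes `tied_embedding_weight` under the head. A hypothetical migration helper (not part of this commit), assuming the YAML has already been loaded into a plain nested dict:

```python
# Hypothetical helper, not from this commit: rename old base_model keys to the
# new ones used in examples/mistral.yaml above.
def migrate_base_model_section(base_model: dict) -> dict:
    if "embeddings_layer" in base_model:
        base_model["embeddings"] = base_model.pop("embeddings_layer")
    if "output_layer" in base_model:
        head = base_model.pop("output_layer")
        # `tied_weight` is renamed to `tied_embedding_weight` in the head section.
        if "tied_weight" in head:
            head["tied_embedding_weight"] = head.pop("tied_weight")
        base_model["head"] = head
    return base_model
```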

fast_llm/core/ops.py

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@ def reduce_op(
     return (input_, handle) if async_op else input_
 
 
-def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
+def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
     if group:
@@ -139,11 +139,11 @@ class _Split(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""
 
     @staticmethod
-    def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:  # noqa
+    def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:  # noqa
        return split_op(input_, group, dim)
 
     @staticmethod
-    def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:  # noqa
+    def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:  # noqa
         ctx.group = group
         ctx.dim = dim
         return split_op(input_, group, dim)
@@ -209,7 +209,7 @@ def reduce_backward(input_: torch.Tensor, group: ProcessGroup | None) -> torch.T
 
 
 @torch._dynamo.disable  # noqa
-def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
+def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
     return _Split.apply(input_, group, dim)
 
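
The annotation fix reflects what `split_op` already did: each rank keeps a single slice of the input, so `split` returns a tensor rather than a list of chunks. A local stand-in (not the library code) illustrating the intended shape behavior, assuming the group size and rank are known:

```python
import torch


def local_split(input_: torch.Tensor, world_size: int, rank: int, dim: int) -> torch.Tensor:
    # torch.chunk produces `world_size` views; keep only the slice owned by this
    # rank, which is what the distributed `split` returns per process.
    return torch.chunk(input_, world_size, dim=dim)[rank]


x = torch.arange(8).reshape(2, 4)
assert local_split(x, world_size=2, rank=1, dim=-1).shape == (2, 2)
```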

fast_llm/engine/base_model/base_model.py

Lines changed: 86 additions & 48 deletions
@@ -1,23 +1,19 @@
 import abc
 import typing
 
-import torch
 import torch.nn
 
 from fast_llm.config import Configurable
-from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig
+from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig
 from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.tensor import ParameterMeta, TensorMeta
-from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
     from fast_llm.engine.inference.runner import InferenceRunner
 
 
-class Module(torch.nn.Module, abc.ABC):
-    """ """
-
+class LayerBase(torch.nn.Module, abc.ABC):
     _is_setup: bool = False
     _distributed: Distributed
 
@@ -27,85 +23,121 @@ def __init__(self, distributed_config: DistributedConfig):
 
     def setup(self, distributed: Distributed) -> None:
         assert not self._is_setup
+        for layer in self.get_layers():
+            if layer is not self:
+                layer.setup(distributed)
         distributed.check_config(self._distributed_config)
         self._distributed = distributed
         self._is_setup = True
 
+    @abc.abstractmethod
+    def get_layers(self) -> list["Layer"]:
+        """
+        The list of layers as meant to be seen by the Fast-LLM engine.
+        May differ from the module configuration seen by pytorch.
+        """
 
-class Layer(Module):
-    # Weight used to determine the stage size
+    def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
+        out = 0
+        for layer in self.get_layers():
+            if layer is self:
+                raise NotImplementedError()
+            out += layer.get_compute_usage(input_, kwargs, config)
+        return out
+
+    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
+        losses = []
+        for layer in self.get_layers():
+            if layer is not self:
+                losses += layer.get_loss_definitions(count)
+        return losses
+
+    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
+        for layer in self.get_layers():
+            if layer is not self:
+                layer.preprocess(batch, kwargs)
+
+
+class Layer(LayerBase):
+    # Weight used to determine the stage size.
     layer_count: float = 1.0
 
+    def get_layers(self) -> list["Layer"]:
+        # Return a breakdown of the layer into atomic ones,
+        # i.e. the list of layers from as seen from the Fast-LLM model.
+        return [self]
+
     @abc.abstractmethod
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
         pass
 
-    def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
-        raise NotImplementedError()
-
+    def unwrap(self) -> "Layer":
+        # Get the actual module contained in this layer,
+        # undoing any wrapping for the Fast-LLM engine (ex. `LayerWithNamespace`)
+        return self
 
-class Sequential(Layer):
-    def __init__(self, distributed_config: DistributedConfig):
-        super().__init__(distributed_config)
-        self.layers = torch.nn.ModuleList(self.get_layers())
 
-    def __getitem__(self, item):
-        return self.layers[item]
+class LayerWithNamespace(Layer):
+    """
+    A layer with its own namespace for preprocessing (kwargs),
+    so that it doesn't inadvertently interact with other layers.
+    TODO: Consider namespace for losses and metrics?
+    """
 
-    def __iter__(self):
-        return iter(self.layers)
+    def __init__(self, layer: Layer, namespace: str = None):
+        super().__init__(layer._distributed_config)
+        self._layer = layer
+        self._namespace = namespace
+        self.layer_count = self._layer.layer_count
+        self.get_compute_usage = self._layer.get_compute_usage
+        self.module_name = self._layer.module_name
 
-    def __len__(self):
-        return len(self.layers)
+    def setup(self, distributed: Distributed) -> None:
+        self._layer.setup(distributed)
+        super().setup(distributed)
 
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
-        for layer in self.layers:
-            input_ = layer(input_, kwargs, losses, metrics)
-        return input_
+        if self._namespace in kwargs:
+            kwargs = kwargs[self._namespace]
+        else:
+            # TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace.
+            # Using kwargs as-is since it's generally unused.
+            assert isinstance(input_, TensorMeta)
+        return self._layer.forward(input_, kwargs, losses, metrics)
 
-    @abc.abstractmethod
-    def get_layers(self) -> list[Layer]:
-        pass
+    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
+        assert self._namespace not in kwargs
+        kwargs[self._namespace] = kwargs.copy()
+        self._layer.preprocess(batch, kwargs[self._namespace])
 
-    def setup(self, distributed: Distributed) -> None:
-        super().setup(distributed)
-        for layer in self.layers:
-            layer.setup(distributed)
+    def unwrap(self) -> "Layer":
+        return self._layer.unwrap()
 
 
-class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):
+class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase):
 
     def __init__(
         self,
         config: BaseModelConfig,
         distributed_config: DistributedConfig,
     ):
         super().__init__(config, distributed_config)
-        for key, value in self.named_modules():
-            value.module_name = key
-        for key, value in self.named_parameters():
-            Assert.custom(isinstance, value, ParameterMeta)
-            # Rename to the parameter full name
-            value.tensor_name = key
 
         # Reference models
         # TODO: Add basic handling (preprocessor) in this class.
         self._reference_models: dict[str, "InferenceRunner"] = {}
 
-    @abc.abstractmethod
-    def get_layers(self) -> list[Layer]:
-        pass
-
     @abc.abstractmethod
     def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]:
+        # TODO Remove (Move batch splitting elsewhere)
         pass
 
     @abc.abstractmethod
-    def preprocess(
+    def preprocess_batch(
         self,
         batch: typing.Any,
         preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None,
@@ -114,13 +146,19 @@ def preprocess(
         iteration: int,
         metrics: dict | None = None,
     ) -> list[tuple[torch.Tensor, dict]]:
+        # TODO Move batch splitting elsewhere, align interface with LayerBase
         pass
 
-    def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
-        # For each tied weight, return the weight and the tuple of layers sharing it.
-        # The weight should be defined in the first layer in the set.
-        # Warning: This may return buffers instead of metas after stage setup.
-        # The name (dict key) is used to insert the weight in the kwargs of the forward pass.
+    def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]:
+        """
+        Return tuples of independently defined metas to tie together.
+        Metas should be compatible, i.e. have the same tensor dimensions.
+        Tied weights are named (dict keys) for convenience only.
+        Warning: Initialization and optimization properties are defined on the first appearance of the tied weight.
+        To prevent any confusion, the metas should be provided in the same order they appear in the model.
+        TODO: Improve?
+        Note: This may return buffers instead of metas after stage setup.
+        """
         return {}
 
     def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None:
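
The key change in this file is the new `LayerBase` interface: composite modules expose their atomic layers through `get_layers()`, and `setup`, `get_compute_usage`, `get_loss_definitions` and `preprocess` aggregate over that list by default, replacing the old `Sequential` container. A minimal sketch with hypothetical classes (not from this commit) of how a composite block is expected to plug in:

```python
from fast_llm.engine.base_model.base_model import Layer, LayerBase, LayerWithNamespace


class DecoderBlock(LayerBase):  # hypothetical composite layer
    def __init__(self, distributed_config, attention: Layer, mlp: Layer):
        super().__init__(distributed_config)
        self.attention = attention
        self.mlp = mlp

    def get_layers(self) -> list[Layer]:
        # The Fast-LLM engine schedules the two atomic layers, while pytorch
        # still sees a single DecoderBlock module; compute usage and loss
        # definitions are summed over this list by LayerBase.
        return [self.attention, self.mlp]


# Optionally, wrap a sub-layer so its preprocessing kwargs live under their own
# namespace and cannot collide with other layers' entries:
# attention = LayerWithNamespace(attention, namespace="block_0.attention")
```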

fast_llm/engine/base_model/config.py

Lines changed: 26 additions & 8 deletions
@@ -4,14 +4,15 @@
 
 from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class
 from fast_llm.engine.config_utils.data_type import DataType
-from fast_llm.utils import compare_nested, log
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.utils import Assert, compare_nested, log
 
 if typing.TYPE_CHECKING:
-    import torch
+    from fast_llm.engine.base_model.base_model import BaseModel
 
 
 @config_class()
-class BaseModelConfig(Config):
+class ModuleConfig(Config):
     """
     Abstract config class for a base model.
     # TODO: Find better name?
@@ -43,7 +44,7 @@ def _get_architecture(self) -> dict[str, typing.Any]:
         return architecture
 
     def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
-        if isinstance(value, BaseModelConfig):
+        if isinstance(value, ModuleConfig):
             # TODO: Make sure all nested configs have an architecture type hint?
             return value._get_architecture()
         elif isinstance(value, Config):
@@ -57,12 +58,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
         return self._serialize_value(value)
 
 
-class Preprocessor(abc.ABC):
-    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
-        pass
+@config_class()
+class BaseModelConfig(ModuleConfig):
+    """
+    Abstract config class for a base model.
+    """
+
+    def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel":
+        from fast_llm.tensor import ParameterMeta
+
+        model = self.base_model_class(self, distributed_config)
+        # Storing the global name of each module and tensor.
+        # Done here because it needs to run right after `model.__init__()`
+        for key, value in model.named_modules():
+            value.module_name = key
+        for key, value in model.named_parameters():
+            Assert.custom(isinstance, value, ParameterMeta)
+            # Rename to the parameter full name
+            value.tensor_name = key
+        return model
 
+    @property
     @abc.abstractmethod
-    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
+    def base_model_class(self) -> type["BaseModel"]:
         pass
 
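
With `get_base_model` on `BaseModelConfig`, module and tensor naming now happens in the config-side factory rather than in `BaseModel.__init__`. A hedged sketch (hypothetical config and model names) of what a concrete config is expected to provide:

```python
from fast_llm.config import config_class
from fast_llm.engine.base_model.config import BaseModelConfig


@config_class()
class MyModelConfig(BaseModelConfig):  # hypothetical concrete config
    @property
    def base_model_class(self) -> type:
        # Imported lazily, mirroring the lazy-import style used elsewhere.
        from my_package.model import MyBaseModel  # hypothetical model class

        return MyBaseModel


# model = my_config.get_base_model(distributed_config)
# After this call every module has `module_name` and every ParameterMeta has
# `tensor_name` set to its full name in the model.
```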

fast_llm/engine/checkpoint/huggingface.py

Lines changed: 7 additions & 5 deletions
@@ -31,14 +31,18 @@ def export_config(cls, config: BaseModelConfig) -> dict:
 
     @classmethod
     @abc.abstractmethod
-    def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]:
+    def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]:
         pass
 
 
 class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC):
     architecture: typing.ClassVar[str]
     base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]]
 
+    def __init__(self, model: "FastLLMModel"):
+        self._exported_config = self._export_config(model.config)
+        super().__init__(model)
+
     @classmethod
     @abc.abstractmethod
     def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]:
@@ -126,10 +130,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
         Assert.eq(config["architecture"], cls.architecture)
         return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})
 
-    def _create_weight_converters(
-        self,
-    ) -> list[WeightConverter]:
-        return self.base_model_converter_class.get_converters(self._model.config.base_model)
+    def _create_weight_converters(self) -> list[WeightConverter]:
+        return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config)
 
     def _load_weights(
         self, config: CheckpointLoadConfig, device
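
The handler now exports the Hugging Face config once in `__init__` and hands it to `get_converters`, so converters can branch on exported values instead of recomputing them. A sketch with hypothetical names (the key used below is an assumption, not taken from this commit):

```python
class MyBaseModelConverter:  # stands in for a HuggingFaceBaseModelConverter subclass
    @classmethod
    def get_converters(cls, config, exported_config: dict) -> list:
        converters: list = []
        # Branch on a field of the exported config, e.g. weight tying
        # (key name assumed for illustration).
        if exported_config.get("tie_word_embeddings", False):
            pass  # e.g. reuse the embedding converter for the output head
        return converters
```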

fast_llm/engine/config_utils/data_type.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
     from triton import language as tl
 
 
-class DataType(str, enum.Enum):
+class DataType(enum.StrEnum):
     """
     An enum to represent data types independently of third party libraries,
     so we can swap them more easily and allow for lazy imports.
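
The practical difference `enum.StrEnum` (Python 3.11+) brings is that members behave as plain strings for `str()` and comparisons. Illustrated with a stand-in enum, since `DataType`'s members are not part of this hunk:

```python
import enum


class Style(enum.StrEnum):
    compact = "compact"


assert Style.compact == "compact"
assert str(Style.compact) == "compact"
# With the old `class Style(str, enum.Enum)` mixin, str(Style.compact) would
# return "Style.compact" rather than the bare value.
```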

fast_llm/engine/config_utils/run.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def __init__(
             self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0
         )
         config_dict = config.to_dict()
-        config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance)
+        config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug)
 
         if self._config.experiment_dir is not None:
             self._experiment_directory = self._config.experiment_dir.resolve()

fast_llm/engine/evaluation/evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def setup(
             phase=PhaseType.validation,
         )
 
-        self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions()
+        self._loss_defs = self._multi_stage.base_model.get_loss_definitions()
         self._evaluation_iterator = None
         self._is_setup = True
 
