8 changes: 4 additions & 4 deletions fast_llm/core/ops.py
@@ -26,7 +26,7 @@ def reduce_op(
return (input_, handle) if async_op else input_


def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""
if group:
@@ -139,11 +139,11 @@ class _Split(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""

@staticmethod
def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa
def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa
return split_op(input_, group, dim)

@staticmethod
def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa
def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa
ctx.group = group
ctx.dim = dim
return split_op(input_, group, dim)
@@ -209,7 +209,7 @@ def reduce_backward(input_: torch.Tensor, group: ProcessGroup | None) -> torch.T


@torch._dynamo.disable # noqa
def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
return _Split.apply(input_, group, dim)


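For reference, a minimal sketch of what the corrected annotation reflects: the op chunks the tensor along `dim` and keeps only the local rank's slice, so a single tensor comes back rather than a list. The body below is illustrative, not the repository implementation.

import torch
from torch.distributed import ProcessGroup


def split_op_sketch(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
    # Nothing to split without a group (or with a single rank).
    if group is None or group.size() == 1:
        return input_
    # Chunk along `dim` and keep this rank's slice: one tensor, hence `-> torch.Tensor`.
    return input_.chunk(group.size(), dim=dim)[group.rank()].contiguous()
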
108 changes: 64 additions & 44 deletions fast_llm/engine/base_model/base_model.py
@@ -1,23 +1,19 @@
import abc
import typing

import torch
import torch.nn

from fast_llm.config import Configurable
from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig
from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import ParameterMeta, TensorMeta
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.inference.runner import InferenceRunner


class Module(torch.nn.Module, abc.ABC):
""" """

class LayerBase(torch.nn.Module, abc.ABC):
_is_setup: bool = False
_distributed: Distributed

@@ -27,81 +23,102 @@ def __init__(self, distributed_config: DistributedConfig):

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
for layer in self.get_layers():
if layer is not self:
layer.setup(distributed)
distributed.check_config(self._distributed_config)
self._distributed = distributed
self._is_setup = True

@abc.abstractmethod
def get_layers(self) -> list["Layer"]:
"""
The list of layers as seen by the Fast-LLM engine.
May differ from the module configuration seen by PyTorch.
"""

class Layer(Module):
# Weight used to determine the stage size
def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
out = 0
for layer in self.get_layers():
if layer is self:
raise NotImplementedError()
out += layer.get_compute_usage(input_, kwargs, config)
return out

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
losses = []
for layer in self.get_layers():
if layer is not self:
losses += layer.get_loss_definitions(count)
return losses

def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
for layer in self.get_layers():
if layer is not self:
layer.preprocess(batch, kwargs)


class Layer(LayerBase):
# Weight used to determine the stage size.
layer_count: float = 1.0

def get_layers(self) -> list["Layer"]:
# Return a breakdown of the layer into atomic ones,
# i.e. the list of layers as seen from the Fast-LLM model.
return [self]

@abc.abstractmethod
def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
pass

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
raise NotImplementedError()

class LayerWithNamespace(Layer):
"""
A layer with its own namespace for preprocessing (kwargs),
so that it doesn't inadvertently interact with other layers.
TODO: Consider namespace for losses and metrics?
"""

class Sequential(Layer):
def __init__(self, distributed_config: DistributedConfig):
super().__init__(distributed_config)
self.layers = torch.nn.ModuleList(self.get_layers())

def __getitem__(self, item):
return self.layers[item]
def __init__(self, layer: Layer, namespace: str):
super().__init__(layer._distributed_config)
self._layer = layer
self._namespace = namespace
self.layer_count = self._layer.layer_count
self.get_compute_usage = self._layer.get_compute_usage

def __iter__(self):
return iter(self.layers)

def __len__(self):
return len(self.layers)
def setup(self, distributed: Distributed) -> None:
self._layer.setup(distributed)
super().setup(distributed)

def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
for layer in self.layers:
input_ = layer(input_, kwargs, losses, metrics)
return input_
return self._layer.forward(input_, kwargs[self._namespace], losses, metrics)

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
assert self._namespace not in kwargs
kwargs[self._namespace] = kwargs.copy()
return self._layer.preprocess(batch, kwargs[self._namespace])

def setup(self, distributed: Distributed) -> None:
super().setup(distributed)
for layer in self.layers:
layer.setup(distributed)


class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):
class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase):

def __init__(
self,
config: BaseModelConfig,
distributed_config: DistributedConfig,
):
super().__init__(config, distributed_config)
for key, value in self.named_modules():
value.module_name = key
for key, value in self.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter full name
value.tensor_name = key

# Reference models
# TODO: Add basic handling (preprocessor) in this class.
self._reference_models: dict[str, "InferenceRunner"] = {}

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass

@abc.abstractmethod
def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]:
# TODO ====== Remove (Move batch splitting elsewhere) ======
pass

@abc.abstractmethod
@@ -114,9 +131,12 @@ def preprocess(
iteration: int,
metrics: dict | None = None,
) -> list[tuple[torch.Tensor, dict]]:
# TODO ====== Move batch splitting elsewhere, align interface with LayerBase ======
pass

def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
# TODO ====== Tied weights ======
# Return tuples of independently defined metas to tie together.
# For each tied weight, return the weight and the tuple of layers sharing it.
# The weight should be defined in the first layer in the set.
# Warning: This may return buffers instead of metas after stage setup.
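
As a rough illustration of the new hierarchy, a minimal atomic layer only needs `forward`; the base class already reports itself via `get_layers()` and uses the default `layer_count` of 1.0. The `NoOpLayer` name and body below are hypothetical, not part of this change.

import torch

from fast_llm.engine.base_model.base_model import Layer


class NoOpLayer(Layer):
    def forward(
        self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
    ) -> torch.Tensor:
        # Identity forward; a real layer would also override `get_compute_usage`.
        return input_


# Wrapping it isolates its preprocessing kwargs under a dedicated key (hypothetical usage):
#   namespaced = LayerWithNamespace(NoOpLayer(distributed_config), "noop")
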
34 changes: 26 additions & 8 deletions fast_llm/engine/base_model/config.py
@@ -4,14 +4,15 @@

from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import compare_nested, log
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.utils import Assert, compare_nested, log

if typing.TYPE_CHECKING:
import torch
from fast_llm.engine.base_model.base_model import BaseModel


@config_class()
class BaseModelConfig(Config):
class ModuleConfig(Config):
"""
Abstract config class for a base model.
# TODO: Find better name?
@@ -43,7 +44,7 @@ def _get_architecture(self) -> dict[str, typing.Any]:
return architecture

def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
if isinstance(value, BaseModelConfig):
if isinstance(value, ModuleConfig):
# TODO: Make sure all nested configs have an architecture type hint?
return value._get_architecture()
elif isinstance(value, Config):
@@ -57,12 +58,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
return self._serialize_value(value)


class Preprocessor(abc.ABC):
def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
pass
@config_class()
class BaseModelConfig(ModuleConfig):
"""
Abstract config class for a base model.
"""

def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel":
from fast_llm.tensor import ParameterMeta

model = self.base_model_class(self, distributed_config)
# Storing the global name of each module and tensor.
# Done here because it needs to run right after `model.__init__()`
for key, value in model.named_modules():
value.module_name = key
for key, value in model.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter's full name
value.tensor_name = key
return model

@property
@abc.abstractmethod
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
def base_model_class(self) -> type["BaseModel"]:
pass


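A hedged sketch of how a downstream config is expected to plug into `get_base_model`: it only has to expose its model class through the new `base_model_class` property. `MyModelConfig`, `my_package` and `MyModel` are placeholder names, not part of the PR.

from fast_llm.config import config_class
from fast_llm.engine.base_model.config import BaseModelConfig


@config_class()
class MyModelConfig(BaseModelConfig):
    @property
    def base_model_class(self) -> type:
        # Imported lazily, mirroring the lazy-import pattern used above.
        from my_package.model import MyModel  # hypothetical model class

        return MyModel


# The multi-stage engine can then build the model without knowing its concrete class:
#   base_model = my_model_config.get_base_model(distributed_config)
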
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/data_type.py
@@ -9,7 +9,7 @@
from triton import language as tl


class DataType(str, enum.Enum):
class DataType(enum.StrEnum):
"""
An enum to represent data types independently of third party libraries,
so we can swap them more easily and allow for lazy imports.
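
The move from the `str, enum.Enum` mixin to `enum.StrEnum` (Python 3.11+) mostly changes how members stringify; a quick throwaway comparison, not taken from the repository:

import enum


class OldDataType(str, enum.Enum):  # previous style
    float32 = "float32"


class NewDataType(enum.StrEnum):  # new style
    float32 = "float32"


# Both compare equal to the raw string...
assert OldDataType.float32 == "float32" and NewDataType.float32 == "float32"
# ...but only the StrEnum stringifies to the bare value.
assert str(OldDataType.float32) == "OldDataType.float32"
assert str(NewDataType.float32) == "float32"
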
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -136,7 +136,7 @@ def __init__(
self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0
)
config_dict = config.to_dict()
config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance)
config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug)

if self._config.experiment_dir is not None:
self._experiment_directory = self._config.experiment_dir.resolve()
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/evaluator.py
@@ -116,7 +116,7 @@ def setup(
phase=PhaseType.validation,
)

self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions()
self._loss_defs = self._multi_stage.base_model.get_loss_definitions()
self._evaluation_iterator = None
self._is_setup = True

5 changes: 0 additions & 5 deletions fast_llm/engine/multi_stage/config.py
@@ -137,11 +137,6 @@ class StageConfig(Config):
desc="Check for tensor-parallel desyncs and log an error if a desync is found. High overhead",
hint=FieldHint.logging,
)
compile_all: bool = Field(
default=False,
desc="Compile the whole model using torch.compile.",
hint=FieldHint.expert,
)


@config_class()
15 changes: 5 additions & 10 deletions fast_llm/engine/multi_stage/multi_stage.py
@@ -25,7 +25,6 @@


class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]):
base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel
_is_setup: bool = False
_flat_shard: torch.Tensor
_shards: dict[str, torch.Tensor]
@@ -46,7 +45,8 @@ def __init__(
stage_filter: set | None = None,
):
super().__init__(config)
self._base_model = self.base_model_class(self._config.base_model, self._config.distributed)
self._base_model = self._config.base_model.get_base_model(self._config.distributed)
self._layers = self._base_model.get_layers()
self._training = None
self._verbose = verbose
self._stage_filter = stage_filter
@@ -67,10 +67,8 @@ def __init__(
self._stages = [
Stage(
config=self._config.multi_stage,
base_model=self._base_model,
layers=self._layers[stage_splits[i] : stage_splits[i + 1]],
distributed_config=self._config.distributed,
begin=stage_splits[i],
end=stage_splits[i + 1],
index=i,
)
for i in (range(self._num_stages))
@@ -510,12 +508,9 @@ def _split_into_stages(self) -> list[int]:
# Create stages (greedy split, could do better).
stage_splits = [0]
layer_counter, last_counter = 0, 0
for i, layer in enumerate(self._base_model):
for i, layer in enumerate(self._layers):
layer_counter += layer.layer_count # noqa
if (
layer_counter >= last_counter + self._config.multi_stage.layers_per_stage
or i == len(self._base_model) - 1
):
if layer_counter >= last_counter + self._config.multi_stage.layers_per_stage or i == len(self._layers) - 1:
stage_splits.append(i + 1)
last_counter = layer_counter
return stage_splits
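
A standalone sketch of the greedy split that the refactor now runs over `self._layers` (the helper name and the plain-list interface are illustrative):

def greedy_stage_splits(layer_counts: list[float], layers_per_stage: float) -> list[int]:
    # Returns stage boundaries as indices into the layer list, starting at 0.
    splits = [0]
    layer_counter, last_counter = 0.0, 0.0
    for i, count in enumerate(layer_counts):
        layer_counter += count
        if layer_counter >= last_counter + layers_per_stage or i == len(layer_counts) - 1:
            splits.append(i + 1)
            last_counter = layer_counter
    return splits


# E.g. five unit-weight layers at two layers per stage -> stages [0:2], [2:4], [4:5].
assert greedy_stage_splits([1.0] * 5, 2.0) == [0, 2, 4, 5]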