 import abc
 import typing

-import torch
 import torch.nn

 from fast_llm.config import Configurable
-from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig
+from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig
 from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.tensor import ParameterMeta, TensorMeta
-from fast_llm.utils import Assert

 if typing.TYPE_CHECKING:
     from fast_llm.engine.inference.runner import InferenceRunner


-class Module(torch.nn.Module, abc.ABC):
-    """ """
-
+class LayerBase(torch.nn.Module, abc.ABC):
     _is_setup: bool = False
     _distributed: Distributed

@@ -27,85 +23,121 @@ def __init__(self, distributed_config: DistributedConfig):

     def setup(self, distributed: Distributed) -> None:
         assert not self._is_setup
+        for layer in self.get_layers():
+            if layer is not self:
+                layer.setup(distributed)
         distributed.check_config(self._distributed_config)
         self._distributed = distributed
         self._is_setup = True

+    @abc.abstractmethod
+    def get_layers(self) -> list["Layer"]:
+        """
+        The list of layers as meant to be seen by the Fast-LLM engine.
+        May differ from the module configuration seen by pytorch.
+        """

-class Layer(Module):
-    # Weight used to determine the stage size
+    def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
+        out = 0
+        for layer in self.get_layers():
+            if layer is self:
+                raise NotImplementedError()
+            out += layer.get_compute_usage(input_, kwargs, config)
+        return out
+
+    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
+        losses = []
+        for layer in self.get_layers():
+            if layer is not self:
+                losses += layer.get_loss_definitions(count)
+        return losses
+
+    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
+        for layer in self.get_layers():
+            if layer is not self:
+                layer.preprocess(batch, kwargs)
+
+
+class Layer(LayerBase):
+    # Weight used to determine the stage size.
     layer_count: float = 1.0

+    def get_layers(self) -> list["Layer"]:
+        # Return a breakdown of the layer into atomic ones,
+        # i.e. the list of layers as seen from the Fast-LLM model.
+        return [self]
+
     @abc.abstractmethod
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
         pass

-    def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
-        raise NotImplementedError()
-
+    def unwrap(self) -> "Layer":
+        # Get the actual module contained in this layer,
+        # undoing any wrapping for the Fast-LLM engine (ex. `LayerWithNamespace`)
+        return self

-class Sequential(Layer):
-    def __init__(self, distributed_config: DistributedConfig):
-        super().__init__(distributed_config)
-        self.layers = torch.nn.ModuleList(self.get_layers())

-    def __getitem__(self, item):
-        return self.layers[item]
+class LayerWithNamespace(Layer):
+    """
+    A layer with its own namespace for preprocessing (kwargs),
+    so that it doesn't inadvertently interact with other layers.
+    TODO: Consider namespace for losses and metrics?
+    """

-    def __iter__(self):
-        return iter(self.layers)
+    def __init__(self, layer: Layer, namespace: str = None):
+        super().__init__(layer._distributed_config)
+        self._layer = layer
+        self._namespace = namespace
+        self.layer_count = self._layer.layer_count
+        self.get_compute_usage = self._layer.get_compute_usage
+        self.module_name = self._layer.module_name

-    def __len__(self):
-        return len(self.layers)
+    def setup(self, distributed: Distributed) -> None:
+        self._layer.setup(distributed)
+        super().setup(distributed)

     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
-        for layer in self.layers:
-            input_ = layer(input_, kwargs, losses, metrics)
-        return input_
+        if self._namespace in kwargs:
+            kwargs = kwargs[self._namespace]
+        else:
+            # TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace.
+            #   Using kwargs as-is since it's generally unused.
+            assert isinstance(input_, TensorMeta)
+        return self._layer.forward(input_, kwargs, losses, metrics)

-    @abc.abstractmethod
-    def get_layers(self) -> list[Layer]:
-        pass
+    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
+        assert self._namespace not in kwargs
+        kwargs[self._namespace] = kwargs.copy()
+        self._layer.preprocess(batch, kwargs[self._namespace])

-    def setup(self, distributed: Distributed) -> None:
-        super().setup(distributed)
-        for layer in self.layers:
-            layer.setup(distributed)
+    def unwrap(self) -> "Layer":
+        return self._layer.unwrap()


-class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):
+class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase):

     def __init__(
         self,
         config: BaseModelConfig,
         distributed_config: DistributedConfig,
     ):
         super().__init__(config, distributed_config)
-        for key, value in self.named_modules():
-            value.module_name = key
-        for key, value in self.named_parameters():
-            Assert.custom(isinstance, value, ParameterMeta)
-            # Rename to the parameter full name
-            value.tensor_name = key

         # Reference models
         # TODO: Add basic handling (preprocessor) in this class.
         self._reference_models: dict[str, "InferenceRunner"] = {}

-    @abc.abstractmethod
-    def get_layers(self) -> list[Layer]:
-        pass
-
     @abc.abstractmethod
     def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]:
+        # TODO Remove (Move batch splitting elsewhere)
         pass

     @abc.abstractmethod
-    def preprocess(
+    def preprocess_batch(
         self,
         batch: typing.Any,
         preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None,
@@ -114,13 +146,19 @@ def preprocess(
         iteration: int,
         metrics: dict | None = None,
     ) -> list[tuple[torch.Tensor, dict]]:
+        # TODO Move batch splitting elsewhere, align interface with LayerBase
         pass

-    def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
-        # For each tied weight, return the weight and the tuple of layers sharing it.
-        # The weight should be defined in the first layer in the set.
-        # Warning: This may return buffers instead of metas after stage setup.
-        # The name (dict key) is used to insert the weight in the kwargs of the forward pass.
+    def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]:
+        """
+        Return tuples of independently defined metas to tie together.
+        Metas should be compatible, i.e. have the same tensor dimensions.
+        Tied weights are named (dict keys) for convenience only.
+        Warning: Initialization and optimization properties are defined on the first appearance of the tied weight.
+            To prevent any confusion, the metas should be provided in the same order they appear in the model.
+            TODO: Improve?
+        Note: This may return buffers instead of metas after stage setup.
+        """
         return {}

     def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None:
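
For illustration only, outside the diff above: a minimal standalone sketch of the kwargs-namespacing behaviour that LayerWithNamespace introduces. It uses hypothetical mock classes (MockLayer, MockLayerWithNamespace, the "scale" key), not the real Fast-LLM API: each wrapped layer preprocesses into a private copy of the shared kwargs dict, so two layers that write the same key no longer collide.

class MockLayer:
    # Hypothetical stand-in for a Layer; writes and reads one kwargs entry.
    def __init__(self, scale: float):
        self._scale = scale

    def preprocess(self, batch, kwargs: dict) -> None:
        # Every instance writes the same key; without namespacing the last write would win.
        kwargs["scale"] = self._scale

    def forward(self, input_, kwargs: dict):
        return input_ * kwargs["scale"]


class MockLayerWithNamespace:
    # Mirrors the preprocess/forward logic of LayerWithNamespace above.
    def __init__(self, layer: MockLayer, namespace: str):
        self._layer = layer
        self._namespace = namespace

    def preprocess(self, batch, kwargs: dict) -> None:
        assert self._namespace not in kwargs
        kwargs[self._namespace] = kwargs.copy()
        self._layer.preprocess(batch, kwargs[self._namespace])

    def forward(self, input_, kwargs: dict):
        return self._layer.forward(input_, kwargs[self._namespace])


layers = [
    MockLayerWithNamespace(MockLayer(2.0), "block_0"),
    MockLayerWithNamespace(MockLayer(10.0), "block_1"),
]
kwargs = {"sequence_length": 2048}  # shared entries stay visible inside each namespace copy
for layer in layers:
    layer.preprocess(None, kwargs)

x = 1.0
for layer in layers:
    x = layer.forward(x, kwargs)
print(x)  # 20.0: each layer read its own "scale" despite using the same key name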