
Commit 0e1c23c

oleksost, jlamypoirier, RaymondLi0, and nitsanluke authored
Apriel SSM/Hybrid (#258)
Co-authored-by: Joel Lamy-Poirier <[email protected]>
Co-authored-by: Toolkit User <[email protected]>
Co-authored-by: Luke Nitish Kumar <[email protected]>
1 parent a134722 commit 0e1c23c

32 files changed: 5623 additions, 173 deletions

fast_llm/engine/optimizer/learning_rate.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR
     begin_step = 0
     for stage_arg_str in config.schedule.split(";"):
         try:
-            for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","):
-                assert begin_step is not None
-                num_steps = int(num_steps)
-                end_step = None if num_steps < 0 else begin_step + num_steps
-                kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
-                if len(stage_args) > 0:
-                    kwargs["end_lr"] = float(stage_args[0])
-                if len(stage_args) > 1:
-                    kwargs["power"] = float(stage_args[1])
-                if len(stage_args) > 2:
-                    raise ValueError(stage_args[2:])
-                stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs))
-                begin_step = end_step
+            stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",")
+            assert begin_step is not None
+            num_steps = int(num_steps)
+            end_step = None if num_steps < 0 else begin_step + num_steps
+            kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
+            if len(stage_args) > 0:
+                kwargs["end_lr"] = float(stage_args[0])
+            if len(stage_args) > 1:
+                kwargs["power"] = float(stage_args[1])
+            if len(stage_args) > 2:
+                raise ValueError(stage_args[2:])
+            stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs))
+            begin_step = end_step
         except Exception:
             raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"')
     return LearningRateSchedule(stages)
```
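This fixes the stage parser: each ";"-separated segment is now unpacked as a single stage of the form `type,num_steps,lr[,end_lr[,power]]`, whereas the old inner `for` loop effectively tried to unpack each comma-separated token into several fields. A minimal standalone sketch of the same parsing rules; the `Stage` dataclass and the stage names in the example string are illustrative, not Fast-LLM's actual `_STAGE_TYPE_MAP` entries:

```python
import dataclasses


@dataclasses.dataclass
class Stage:
    kind: str
    begin_step: int
    end_step: int | None
    lr: float
    end_lr: float | None = None
    power: float | None = None


def parse_schedule(schedule: str) -> list[Stage]:
    """Parse 'type,num_steps,lr[,end_lr[,power]]' stages separated by ';'."""
    stages, begin_step = [], 0
    for stage_arg_str in schedule.split(";"):
        stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",")
        num_steps = int(num_steps)
        # A negative step count means the stage runs until the end of training.
        end_step = None if num_steps < 0 else begin_step + num_steps
        stages.append(Stage(stage_type, begin_step, end_step, float(lr), *map(float, stage_args)))
        begin_step = end_step
    return stages


# Hypothetical schedule: 100 steps at 1e-4, then an open-ended decay to 1e-5 with power 1.0.
print(parse_schedule("constant,100,0.0001;cosine,-1,0.0001,0.00001,1.0"))
```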

fast_llm/layers/common/config.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -11,6 +11,18 @@
 from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
 
 
+@config_class()
+class LLMBlockConfig(BaseModelConfig):
+    _abstract = False
+
+    per_layer_lr_scale: list[float] | None = Field(
+        default=None,
+        desc="Custom learning rate scale for each layer.",
+        doc="May be used to freeze some layers by setting their scale to zero.",
+        hint=FieldHint.feature,
+    )
+
+
 class NormalizationImplementation(str, enum.Enum):
     """
     An enum for the available implementations of layer norm.
@@ -68,7 +80,7 @@ class NormalizationConfig(BaseModelConfig):
         valid=check_field(Assert.geq, 0),
     )
 
-    def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm":
+    def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
         from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
         from fast_llm.tensor import init_uniform_
 
@@ -77,6 +89,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm":
             "eps": self.epsilon,
             "implementation": self.implementation,
             "zero_centered": self.zero_centered,
+            "lr_scale": lr_scale,
         }
         if self.initialization_range:
             mean = 0 if self.zero_centered else 1
```
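As the field docs above say, `per_layer_lr_scale` and the new `lr_scale` argument exist to scale, or zero out and thereby freeze, the learning rate of individual layers and parameters. A self-contained sketch of the general idea using plain PyTorch parameter groups, as an analogy rather than Fast-LLM's actual optimizer wiring:

```python
import torch

base_lr = 3e-4
layers = [torch.nn.Linear(16, 16) for _ in range(3)]
per_layer_lr_scale = [0.0, 0.5, 1.0]  # layer 0 frozen, layer 1 at half the base rate

# One parameter group per layer, with the learning rate multiplied by that layer's scale.
param_groups = [
    {"params": layer.parameters(), "lr": base_lr * scale}
    for layer, scale in zip(layers, per_layer_lr_scale)
]
optimizer = torch.optim.AdamW(param_groups)

for scale, group in zip(per_layer_lr_scale, optimizer.param_groups):
    print(f"scale={scale} -> effective lr={group['lr']:g}")
```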

fast_llm/layers/common/normalization.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -155,6 +155,7 @@ def __init__(
         weight_init_method=None,
         bias_init_method=init_zeros_,
         zero_centered: bool = False,
+        lr_scale: float | None = None,
     ):
         super().__init__()
         assert hidden_dim.parallel_dim is None
@@ -193,12 +194,14 @@ def __init__(
             init_method=weight_init_method,
             weight_decay=False,
             auto_grad_accumulation=implementation == NormalizationImplementation.torch,
+            lr_scale=lr_scale,
         )
         self.bias = ParameterMeta.from_dims(
             (hidden_dim,),
             init_method=bias_init_method,
             weight_decay=False,
             auto_grad_accumulation=implementation == NormalizationImplementation.torch,
+            lr_scale=lr_scale,
         )
         self.normalized_shape = self.weight.shape
 
@@ -236,6 +239,7 @@ def __init__(
         implementation: NormalizationImplementation = NormalizationImplementation.auto,
         weight_init_method=None,
         zero_centered: bool = False,
+        lr_scale: float | None = None,
     ):
         super().__init__()
         assert hidden_dim.parallel_dim is None
@@ -269,6 +273,7 @@ def __init__(
             init_method=weight_init_method,
             weight_decay=False,
             auto_grad_accumulation=True,
+            lr_scale=lr_scale,
         )
         self.normalized_shape = self.weight.shape
 
```

fast_llm/layers/language_model/config.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -155,6 +155,25 @@ class LanguageModelBaseConfig(BaseModelConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.geq, 0),
     )
+    embeddings_lr_scale: float | None = Field(
+        default=None,
+        desc="Learning rate scale for the word embeddings.",
+        doc="May be used to freeze some layers by setting their scale to zero.",
+        hint=FieldHint.feature,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
+    output_lr_scale: float | None = Field(
+        default=None,
+        desc="Custom learning rate scale for the output weights.",
+        doc="May be used to freeze the output weights by setting their scale to zero.",
+        hint=FieldHint.feature,
+    )
+    prediction_loss_coefficient: list[float] | None = Field(
+        default=None,
+        desc="Loss coefficient for each prediction head.",
+        doc="If not provided, all heads are equally weighted.",
+        hint=FieldHint.feature,
+    )
 
     def _validate(self) -> None:
         self.transformer.validate()
@@ -173,6 +192,10 @@ def _validate(self) -> None:
         if self.distillation_model is not None:
             if self.prediction_heads > 1:
                 raise NotImplementedError("Multi-token prediction not supported with distillation.")
+        if isinstance(self.prediction_loss_coefficient, list):
+            Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
+            for coeff in self.prediction_loss_coefficient:
+                Assert.geq(coeff, 0)
 
     def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
         self.transformer.setup_tensor_space(tensor_space)
```

fast_llm/layers/language_model/embedding.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -62,6 +62,7 @@ def __init__(
                 min_val=config.init_method_min_embed,
                 max_val=config.init_method_max_embed,
             ),
+            lr_scale=config.embeddings_lr_scale,
         )
         if self._use_absolute_position_embeddings:
             self.position_embeddings_weight = ParameterMeta.from_dims(
@@ -72,6 +73,7 @@
                 max_val=config.init_method_max_embed,
             ),
             allow_sequence_tensor_parallel=not config.parallel_embeddings,
+            lr_scale=config.embeddings_lr_scale,
         )
 
         # PEFT.
```

fast_llm/layers/language_model/head.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -60,6 +60,9 @@ def __init__(
 
         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
 
+        self._loss_coefficient = (
+            config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0
+        )
         self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
         self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
         self._logits_scale_factor = config.logits_scale_factor
@@ -109,6 +112,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
                 min_val=config.init_method_min_embed,
                 max_val=config.init_method_max_embed,
             ),
+            lr_scale=config.output_lr_scale,
         )
 
     def forward(
@@ -139,7 +143,7 @@
         else:
             if self.training:
                 # Backward hook to compute the gradient of the loss
-                shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+                shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient)
             # MTP: Return shared_hidden to be used by the next head.
             return shared_hidden
```
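`prediction_loss_coefficient` feeds `self._loss_coefficient`, which replaces the hard-coded `1.0` in `AuxiliaryLoss.apply`, so each multi-token-prediction head backpropagates its loss scaled by its own coefficient, and a coefficient of 0 silences a head. A rough standalone illustration of that weighting with plain autograd, not the Fast-LLM head itself:

```python
import torch

shared_hidden = torch.randn(4, 8, requires_grad=True)
# Pretend each prediction head turns the shared hidden state into a scalar loss.
head_losses = [shared_hidden.pow(2).mean(), shared_hidden.abs().mean()]
prediction_loss_coefficient = [1.0, 0.0]  # the second head contributes no gradient

total = sum(c * loss for c, loss in zip(prediction_loss_coefficient, head_losses))
total.backward()
print(shared_hidden.grad.norm())  # gradient comes only from the first head
```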

fast_llm/layers/ssm/config.py

Lines changed: 30 additions & 7 deletions
```diff
@@ -1,7 +1,8 @@
-from fast_llm.config import Field, FieldHint, check_field, config_class
-from fast_llm.engine.base_model.config import BaseModelConfig
+import enum
+
+from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
 from fast_llm.functional.config import ActivationType
-from fast_llm.layers.common.config import NormalizationConfig
+from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig
 from fast_llm.utils import Assert
 
 
@@ -20,8 +21,19 @@ class SSMDimNames:
     v_heads = "v_heads"  # Number of V heads
 
 
+class SSMBlockType(str, enum.Enum):
+    """
+    An enum for the available mamba types for the MLP layer.
+    """
+
+    mamba = "m"
+    mamba2_discrete = "m2d"
+    mamba2 = "m2"
+    transformer = "t"
+
+
 @config_class()
-class SSMConfig(BaseModelConfig):
+class SSMConfig(LLMBlockConfig):
     _abstract = False
 
     # Normalization
@@ -53,7 +65,8 @@ class SSMConfig(BaseModelConfig):
         desc="Whether to use bias in SSM layers",
         hint=FieldHint.architecture,
     )
-    dt_rank: int = Field(
+
+    dt_rank: None | int = Field(
         default=None,
         desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)",
         hint=FieldHint.architecture,
@@ -102,12 +115,22 @@ class SSMConfig(BaseModelConfig):
         valid=check_field(Assert.gt, 0),
     )
 
+    d_inner: None | int = Field(
+        default=None,
+        desc="Inner dimension for Mamba2 blocks.",
+        hint=FieldHint.core,
+    )
+    mamba_lr_scale: float | None = Field(
+        default=None,
+        desc="Learning rate scale for Mamba blocks.",
+        hint=FieldHint.feature,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
+
     def _validate(self) -> None:
         with self._set_implicit_default():
             if self.activation_type is None:
                 self.activation_type = ActivationType.silu
-            if self.dt_rank is None:
-                self.dt_rank = -1  # set to -1, it will be overwrittem in ssm validation
 
         super()._validate()
         Assert.geq(self.dt_max, self.dt_min)
```
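`SSMBlockType` assigns each block kind a short code ("t", "m", "m2", "m2d"), which is how a hybrid Apriel-style stack can be described layer by layer. A hedged sketch of decoding such a pattern; the `hybrid_block_layout` name and list format are assumptions for illustration, not necessarily the actual config key:

```python
import enum


class SSMBlockType(str, enum.Enum):
    mamba = "m"
    mamba2_discrete = "m2d"
    mamba2 = "m2"
    transformer = "t"


# Hypothetical per-layer layout for a small hybrid model: attention blocks
# interleaved with discrete Mamba2 blocks.
hybrid_block_layout = ["t", "m2d", "m2d", "t", "m2d", "m2d"]
blocks = [SSMBlockType(code) for code in hybrid_block_layout]
print([block.name for block in blocks])
```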

fast_llm/layers/ssm/discrete_mamba2.py

Lines changed: 49 additions & 10 deletions
```diff
@@ -1,6 +1,6 @@
+import logging
 import math
 
-import causal_conv1d
 import einops
 import mamba_ssm.ops.triton.ssd_combined
 import torch
@@ -9,6 +9,16 @@
 from fast_llm.layers.common.linear import Linear
 from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
 from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_
+from fast_llm.utils import get_lr_scale
+
+logger = logging.getLogger(__name__)
+
+try:
+    import causal_conv1d
+except ImportError:
+    # this is needed since we cannot use causal_conv1d on B200 GPUs for now
+    logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead")
+    causal_conv1d = None
 
 """
 This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
@@ -44,6 +54,9 @@ def __init__(
         bias = config.add_bias_linear
         self.layer_idx = layer_idx
         self._return_input = return_input
+        layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None
+        mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale)
+        logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}")
 
         td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim)
         td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim)
@@ -67,31 +80,41 @@ def __init__(
 
         # TODO: double check initializations
         # Projections
-        self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size))
+        self.in_proj = Linear(
+            td_model,
+            td_inner_proj,
+            bias=bias,
+            weight_init_method=kaiming_init_(td_model.size),
+            lr_scale=mamba_layer_lr_scale,
+        )
         self.z_bias = (
             ParameterMeta.from_dims(
                 (td_inner,),
                 weight_decay=False,
                 init_method=init_zeros_,
+                lr_scale=mamba_layer_lr_scale,
             )
             if not bias
             else 0.0
         )
 
-        # Convolutional layer
         self.conv1d_weight = ParameterMeta.from_dims(
             (td_conv, TensorDim("1", 1), td_conv_kernel),
             init_method=init_uniform_(
                 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
             ),  # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
+            lr_scale=mamba_layer_lr_scale,
+        )
+        self.conv1d_bias = ParameterMeta.from_dims(
+            (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale
         )
-        self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight))
 
         # D "skip" parameter
         self.D = ParameterMeta.from_dims(
            (td_n_qk_heads,),
             weight_decay=False,
             init_method=init_ones_,
+            lr_scale=mamba_layer_lr_scale,
         )
 
         # out_proj
@@ -100,6 +123,7 @@ def __init__(
             td_model,
             bias=bias,
             weight_init_method=kaiming_init_(td_inner.size),
+            lr_scale=mamba_layer_lr_scale,
         )
 
     @property
@@ -210,10 +234,25 @@ def forward(self, hidden_states, kwargs):
 
     def convolutional_forward(self, xBC, padded_len):
         """Convolutional layer forward pass for the full sequence."""
-        xBC = causal_conv1d.causal_conv1d_fn(
-            xBC.transpose(1, 2),
-            einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
-            self.conv1d_bias,
-            activation=None if self.activation_name == "identity" else self.activation_name,
-        ).transpose(1, 2)
+        if causal_conv1d is None or self.activation_name not in [
+            "silu",
+            "swish",
+            "identity",
+        ]:
+            xBC = self.act(
+                torch.nn.functional.conv1d(
+                    xBC.transpose(1, 2),
+                    self.conv1d_weight,
+                    bias=self.conv1d_bias,
+                    groups=self.conv1d_weight.shape[0],
+                    padding=self.conv_kernel_size - 1,
+                )[..., :padded_len].transpose(1, 2)
+            )
+        else:
+            xBC = causal_conv1d.causal_conv1d_fn(
+                xBC.transpose(1, 2),
+                einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
+                self.conv1d_bias,
+                activation=None if self.activation_name == "identity" else self.activation_name,
+            ).transpose(1, 2)
         return xBC
```
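The fallback path reproduces a causal depthwise convolution with `torch.nn.functional.conv1d`: `groups` equal to the channel count makes the convolution per-channel, and padding by `kernel_size - 1` followed by trimming to the original sequence length keeps it causal. A minimal standalone check of that construction in plain PyTorch, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

batch, channels, seq_len, kernel_size = 2, 6, 10, 4
x = torch.randn(batch, channels, seq_len)       # (B, C, L), i.e. already transposed
weight = torch.randn(channels, 1, kernel_size)  # depthwise: one length-K filter per channel
bias = torch.randn(channels)


def causal_depthwise_conv(inputs: torch.Tensor) -> torch.Tensor:
    # Pad by K-1, then keep only the first L outputs, so position t sees inputs[..., :t+1].
    return F.conv1d(inputs, weight, bias=bias, groups=channels, padding=kernel_size - 1)[..., :seq_len]


y = causal_depthwise_conv(x)
x_future_zeroed = x.clone()
x_future_zeroed[..., 6:] = 0.0  # changing the "future" must not affect earlier outputs
print(torch.allclose(y[..., :6], causal_depthwise_conv(x_future_zeroed)[..., :6]))  # True
```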
