Commit 39b018a

Update base image to 25.05 (#295)
1 parent e2f36d8 commit 39b018a

7 files changed (+89, -70 lines)

.github/workflows/docs.yaml (1 addition, 1 deletion)

@@ -33,7 +33,7 @@ jobs:
           pip install pybind11
           FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
           MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
-          pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]"
       - name: Build the documentation
         run: mkdocs build


Dockerfile (9 additions, 2 deletions)

@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1.7-labs
-FROM nvcr.io/nvidia/pytorch:24.11-py3
+FROM nvcr.io/nvidia/pytorch:25.05-py3

 # Install dependencies.
 RUN apt-get update \
@@ -24,13 +24,20 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to
     /usr/local/lib/python3.12/dist-packages \
     /usr/local/lib/python3.12/dist-packages/__pycache__

+# The base image enforces versions for things like pytest for no good reason.
+ENV PIP_CONSTRAINT=""
+# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
+# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
+# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
+RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/[email protected]"
+RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/[email protected]"
 # Copy dependency files with universal write permissions for all users.
 COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
 COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
 COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

 # Install dependencies within the virtual environment.
-RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,DEV]"
+RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]"

 # Copy the remaining source code with universal write permissions.
 COPY --chmod=777 ./Megatron-LM Megatron-LM
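
The two new RUN steps pre-build the SSM kernels against the 25.05 base image so later layers do not trigger a rebuild. A minimal sketch for checking that the kernels are importable in a container built from this image (illustrative only, not part of the commit):

    # Run inside a GPU container built from this Dockerfile.
    import torch
    from causal_conv1d import causal_conv1d_fn
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

    print(torch.__version__)  # setup.cfg now requires torch>=2.7.0
    print(causal_conv1d_fn)
    print(mamba_chunk_scan_combined)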

fast_llm/functional/triton/mlp.py (12 additions, 13 deletions)

@@ -25,9 +25,6 @@
 from fast_llm.functional.triton.sparse_linear import output_sparse_matmul
 from fast_llm.tensor import param_get_and_unset_is_zero

-# Triton requires global variables to be annotated with `constexpr`.
-_TritonActivationType: tl_constexpr = ActivationType
-

 @triton_jit()
 def triton_mlp_activation_forward_kernel(
@@ -50,18 +47,19 @@ def triton_mlp_activation_forward_kernel(

     input_ = tl.load(input_ptr, mask=mask).to(tl.float32)

-    if activation_type == _TritonActivationType.gelu:
+    # Triton doesn't like enums, so we use str instead of ActivationType.
+    if activation_type == "gelu":
         tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
         tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
         out = input_ * 0.5 * (1.0 + tanh)
-    elif activation_type == _TritonActivationType.silu:
+    elif activation_type == "silu":
         out = input_ / (1 + tl.exp(-input_))
-    elif activation_type == _TritonActivationType.relu:
+    elif activation_type == "relu":
         out = tl.where(input_ > 0, input_, 0)
-    elif activation_type == _TritonActivationType.squared_relu:
+    elif activation_type == "squared_relu":
         relu_out = tl.where(input_ > 0, input_, 0)
         out = relu_out * relu_out
-    elif activation_type == _TritonActivationType.identity:
+    elif activation_type == "identity":
         out = input_
     else:
         tl.static_assert(False, activation_type)
@@ -100,28 +98,29 @@ def triton_mlp_activation_backward_kernel(
     input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
     output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32)

-    if activation_type == _TritonActivationType.gelu:
+    # Triton doesn't like enums, so we use str instead of ActivationType.
+    if activation_type == "gelu":
         tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
         tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
         grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh)
         if gated or recompute:
             out = input_ * 0.5 * (1.0 + tanh)
-    elif activation_type == _TritonActivationType.silu:
+    elif activation_type == "silu":
         exp = tl.exp(-input_)
         sigma = 1 / (1 + exp)
         grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp)
         if gated or recompute:
             out = input_ * sigma
-    elif activation_type == _TritonActivationType.relu:
+    elif activation_type == "relu":
         grad = tl.where(input_ > 0, 1, 0)
         if gated or recompute:
             out = tl.where(input_ > 0, input_, 0)
-    elif activation_type == _TritonActivationType.squared_relu:
+    elif activation_type == "squared_relu":
         relu_out = tl.where(input_ > 0, input_, 0)
         grad = 2 * relu_out
         if gated or recompute:
             out = relu_out * relu_out
-    elif activation_type == _TritonActivationType.identity:
+    elif activation_type == "identity":
         grad = 1
         if gated or recompute:
             out = input_
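
The kernels now compare a `tl.constexpr` string instead of a module-level enum alias, so the branch is still resolved at compile time without relying on a Python enum inside Triton code. A standalone sketch of the same pattern (illustrative only, not Fast-LLM code; the kernel and helper names are made up):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _activation_kernel(x_ptr, out_ptr, n, activation: tl.constexpr, block_size: tl.constexpr):
        offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
        mask = offsets < n
        x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
        # The string is a compile-time constant, so only the selected branch is compiled.
        if activation == "relu":
            out = tl.where(x > 0, x, 0.0)
        elif activation == "silu":
            out = x / (1 + tl.exp(-x))
        else:
            tl.static_assert(False, activation)
        tl.store(out_ptr + offsets, out, mask=mask)


    def apply_activation(x: torch.Tensor, activation: str = "silu") -> torch.Tensor:
        # A caller holding an enum would pass its string value, e.g. ActivationType.silu.value.
        out = torch.empty_like(x)
        n = x.numel()
        _activation_kernel[(triton.cdiv(n, 1024),)](x, out, n, activation=activation, block_size=1024)
        return out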

fast_llm/layers/ssm/discrete_mamba2.py (26 additions, 15 deletions)

@@ -2,7 +2,6 @@
 import math

 import einops
-import mamba_ssm.ops.triton.ssd_combined
 import torch

 from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
@@ -13,12 +12,22 @@

 logger = logging.getLogger(__name__)

+
 try:
-    import causal_conv1d
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_chunk_scan_combined  # noqa
+
+    _mamba_available = True
 except ImportError:
-    # this is needed since we cannot use causal_conv1d on B200 GPUs for now
-    logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead")
-    causal_conv1d = None
+    _mamba_available = False
+
+
+try:
+    from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn  # noqa
+
+    _causal_conv1d_available = True
+except ImportError:
+    _causal_conv1d_available = False
+

 """
 This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
@@ -148,6 +157,8 @@ def forward(self, hidden_states, kwargs):
         outputs["hidden_states"]: (B, L, D).
         outputs["state"]: inference cache.
         """
+
+        assert _mamba_available
         input_ = hidden_states
         outputs = {}
         # assert state is None
@@ -201,7 +212,7 @@ def forward(self, hidden_states, kwargs):
         C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads)

         # SSM forward
-        result = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined(
+        result = _mamba_chunk_scan_combined(
             x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1),
             dt=A_log,
             dt_softplus=True,
@@ -234,11 +245,18 @@ def forward(self, hidden_states, kwargs):

     def convolutional_forward(self, xBC, padded_len):
         """Convolutional layer forward pass for the full sequence."""
-        if causal_conv1d is None or self.activation_name not in [
+        if _causal_conv1d_available and self.activation_name in (
             "silu",
             "swish",
             "identity",
-        ]:
+        ):
+            xBC = _causal_conv1d_fn(
+                xBC.transpose(1, 2),
+                einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
+                self.conv1d_bias,
+                activation=None if self.activation_name == "identity" else self.activation_name,
+            ).transpose(1, 2)
+        else:
             xBC = self.act(
                 torch.nn.functional.conv1d(
                     xBC.transpose(1, 2),
@@ -248,11 +266,4 @@ def convolutional_forward(self, xBC, padded_len):
                     padding=self.conv_kernel_size - 1,
                 )[..., :padded_len].transpose(1, 2)
             )
-        else:
-            xBC = causal_conv1d.causal_conv1d_fn(
-                xBC.transpose(1, 2),
-                einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
-                self.conv1d_bias,
-                activation=None if self.activation_name == "identity" else self.activation_name,
-            ).transpose(1, 2)
         return xBC
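
Both SSM layers now follow the same convention: the optional kernel is imported at module load inside a try/except that only records an availability flag, and the hard failure is deferred to the forward pass with an assert. In isolation the pattern looks like this (a generic sketch; `some_optional_package` and `fast_kernel` are made-up stand-ins for the mamba_ssm and causal_conv1d entry points):

    try:
        from some_optional_package import fast_kernel as _fast_kernel  # hypothetical optional dependency
        _fast_kernel_available = True
    except ImportError:
        _fast_kernel_available = False


    def forward(x):
        # Importing the module never fails; using the layer without the dependency does.
        assert _fast_kernel_available, "install the optional SSM dependencies to use this layer"
        return _fast_kernel(x)

This lets the modules import cleanly on machines without the compiled kernels, while any attempt to actually run the layer fails with a clear assertion.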

fast_llm/layers/ssm/mamba_layer.py (9 additions, 2 deletions)

@@ -2,7 +2,6 @@
 from typing import Callable

 import einops
-import mamba_ssm.ops.selective_scan_interface
 import torch

 from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
@@ -11,6 +10,13 @@
 from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_
 from fast_llm.utils import get_lr_scale

+try:
+    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn  # noqa
+
+    _mamba_available = True
+except ImportError:
+    _mamba_available = False
+
 """
 Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba.
 For now it only supports training and not inference.
@@ -153,6 +159,7 @@ def __init__(
         self._return_input = return_input

     def forward(self, hidden_states, kwargs):
+        assert _mamba_available
         batch, seqlen, dim = hidden_states.shape

         # We do matmul and transpose BLH -> HBL at the same time
@@ -167,7 +174,7 @@ def forward(self, hidden_states, kwargs):
         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
         # In the backward pass we write dx and dz next to each other to avoid torch.cat
         # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s
-        out = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn(
+        out = _mamba_inner_fn(
             xz,
             self.conv1d_weight,
             self.conv1d_bias,

setup.cfg (28 additions, 24 deletions)

@@ -6,54 +6,58 @@ packages = find_namespace:
 include_package_data = True
 python_requires = >=3.12
 install_requires =
-    requests>=2.32.3
-    PyYAML>=6.0.1
-    pybind11>=2.5.0
-    packaging>=24.1
+    requests>=2.32.4
+    PyYAML>=6.0.2
+    pybind11>=2.13.6
+    packaging>=25.0

 [options.extras_require]
 # Required to use the main functionality of Fast-LLM
 # To install on cpu environment (ex. for IDE support):
 # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation
 CORE =
     # Available through the nvidia base image
-    torch>=2.5.0
+    torch>=2.7.0
     # Numpy major needs to match torch
-    numpy>=1.24.4,<2.0.0
+    numpy>=1.26.4,<2.0.0
     # Used for checkpoints
-    safetensors>=0.4.4
+    safetensors>=0.5.3
     # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
-    flash-attn==2.7.2.post1
-    mamba_ssm==2.2.4
+    flash-attn==2.7.3


-# Required for some optional features and tools.
+# Small packages required for some optional features and tools.
 OPTIONAL =
-    # Huggingface tools
-    transformers>=4.44.2
-    hf-transfer>=0.1.8
-    datasets>=3.1.0
-    huggingface-hub>=0.28.1
     # Weights and biases
-    wandb>=0.17.7
+    wandb>=0.20.1
     # Hydra
     hydra-core>=1.3.2
     omegaconf>=2.3.0
     # Miscellaneous
-    requests>=2.32.3
-    tqdm>=4.66.3
-    # For causal_conv1d
-    causal_conv1d>=1.4.0
+    tqdm>=4.67.1
+
+# Huggingface tools
+HUGGINGFACE =
+    transformers>=4.52.4
+    hf-transfer>=0.1.9
+    datasets>=3.6.0
+    huggingface-hub>=0.32.6
+
+# Required to run SSMs
+# To install on cpu environment (ex. for IDE support):
+# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation
+SSM =
+    mamba_ssm[causal-conv1d]==2.2.4

 DEV =
     # Pre-commit git hook
-    pre-commit>=4.0.1
+    pre-commit>=4.2.0
     # Required for testing
-    pytest>=8.3.2
+    pytest>=8.4.0
     pytest-depends>=1.0.1
-    pytest-xdist>=3.6.1
+    pytest-xdist>=3.7.0
     # Somehow needed for Megatron to work with base image 24.11
-    setuptools>=75.6.0
+    setuptools>=80.9.0

 # Required for building the documentation
 DOCS =
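
The former OPTIONAL group is now split into OPTIONAL, HUGGINGFACE, and SSM, which the Dockerfile and workflows select explicitly. A small sketch for listing the resulting extras groups from the repository root (illustrative only):

    import configparser

    cfg = configparser.ConfigParser()
    cfg.optionxform = str  # keep the extras' original capitalization (CORE, SSM, ...)
    cfg.read("setup.cfg")
    for name, value in cfg["options.extras_require"].items():
        requirements = [line.strip() for line in value.splitlines() if line.strip()]
        print(f"{name}: {len(requirements)} requirements")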

tests/test_ssms.py (4 additions, 13 deletions)

@@ -14,24 +14,15 @@
 from fast_llm.engine.schedule.schedule import Schedule
 from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames
 from fast_llm.layers.ssm.config import SSMBlockType
+from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
+from fast_llm.layers.ssm.llamba_block import LlambaBlock
+from fast_llm.layers.ssm.mamba_layer import MambaLayer
 from fast_llm.layers.transformer.config import TransformerKwargs
 from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat
 from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat
+from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel
 from tests.common import get_hybrid_config, materialize_meta_tensors

-try:
-    from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
-    from fast_llm.layers.ssm.llamba_block import LlambaBlock
-    from fast_llm.layers.ssm.mamba_layer import MambaLayer
-    from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel
-except Exception:
-    MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = (
-        None,
-        None,
-        None,
-        None,
-    )
-
 try:
     from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel
 except ImportError:

0 commit comments
