Skip to content

Commit 67abaca

Browse files
authored
build: Install flash-attn (#335)
Signed-off-by: oliver könig <okoenig@nvidia.com>
1 parent 0b3874b commit 67abaca

File tree

3 files changed

+371
-338
lines changed

3 files changed

+371
-338
lines changed

docker/common/install.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ main() {
135135
uv sync \
136136
--link-mode copy \
137137
--locked \
138+
--extra fa \
138139
--all-groups ${UV_ARGS[@]}
139140
# Install the package
140141
uv pip install --no-deps -e .

pyproject.toml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ trtllm = [
8888
"cuda-python~=12.8.0",
8989
]
9090
trt-onnx = ["tensorrt==10.11.0.33", "transformers==4.51.3", "onnx==1.18.0"]
91+
fa = ["flash-attn==2.8.1"]
9192

9293
[dependency-groups]
9394
# This is a default group so that we install these even with bare `uv sync`
@@ -109,6 +110,7 @@ nemo-run = ["nemo-run"]
109110

110111
[tool.uv.sources]
111112
xformers = [{ index = "pytorch-cu128" }]
113+
torch = [{ index = "pytorch-cu128" }]
112114
vllm = [
113115
{ index = "pytorch-cu128", marker = "python_version < '3.9' and platform_machine == 'x86_64'" },
114116
{ index = "pypi", marker = "platform_machine == 'aarch64'" },
@@ -119,7 +121,11 @@ transformer-engine = { git = "https://github.com/NVIDIA/TransformerEngine.git",
119121

120122
[tool.uv]
121123
# Currently, TE must be built with no build-isolation b/c it requires torch
122-
no-build-isolation-package = ["transformer-engine", "transformer-engine-torch"]
124+
no-build-isolation-package = [
125+
"transformer-engine",
126+
"transformer-engine-torch",
127+
"flash-attn",
128+
]
123129
# Always apply the build group since dependencies like TE/mcore/nemo-run require build dependencies
124130
# and this lets us assume they are implicitly installed with a simply `uv sync`. Ideally, we'd
125131
# avoid including these in the default dependency set, but for now it's required.
@@ -137,6 +143,11 @@ override-dependencies = [
137143
]
138144
prerelease = "allow"
139145

146+
# Needed when building from source
147+
[[tool.uv.dependency-metadata]]
148+
name = "flash-attn"
149+
requires-dist = ["torch", "einops", "setuptools", "psutil", "ninja"]
150+
140151
[[tool.uv.index]]
141152
name = "pypi"
142153
url = "https://pypi.org/simple"
@@ -205,4 +216,4 @@ convention = "google"
205216
# Ignore all files that end in `_test.py`.
206217
"*_test.py" = ["D"]
207218
# Ignore F401 (import but unused) in __init__.py
208-
"__init__.py" = ["F401"]
219+
"__init__.py" = ["F401"]

0 commit comments

Comments
 (0)