Commit 2e37745

tjohnson31415 and njhill committed
bump: update to PyTorch 2.2 and Flash 2.5.2
Co-authored-by: Nick Hill <[email protected]>
1 parent 52170da commit 2e37745

6 files changed (+24 −47 lines)


Dockerfile

Lines changed: 4 additions & 4 deletions
@@ -1,9 +1,9 @@
 ## Global Args #################################################################
 ARG BASE_UBI_IMAGE_TAG=9.3-1552
 ARG PROTOC_VERSION=25.2
-#ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
-ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
-ARG PYTORCH_VERSION=2.3.0.dev20240125
+ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
+# ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
+ARG PYTORCH_VERSION=2.2.0
 ARG PYTHON_VERSION=3.11
 
 ## Base Layer ##################################################################
@@ -205,7 +205,7 @@ RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu11
 
 ## Build flash attention v2 ####################################################
 FROM python-builder as flash-att-v2-builder
-ARG FLASH_ATT_VERSION=v2.3.6
+ARG FLASH_ATT_VERSION=v2.5.2
 
 WORKDIR /usr/src/flash-attention-v2
 
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 3 additions & 8 deletions
@@ -243,20 +243,16 @@ def forward(
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -267,11 +263,10 @@ def forward(
             layer_past[layer_past_present_indices] = kv
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 layer_past[:, 0],
                 layer_past[:, 1],
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -280,7 +275,7 @@ def forward(
                 False,
             )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(attn_output.reshape(-1, self.num_heads * self.head_size))
 
 
 class LlamaMLP(nn.Module):
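
In each modeling file the caller no longer preallocates an output tensor with torch.empty_like(query); the updated attention() wrapper allocates the output itself and returns it. Because the tensor handed back this way is not guaranteed to be contiguous, the final projection switches from .view() to .reshape(). A minimal, CPU-only illustration of that distinction (shapes are invented for the example; nothing below comes from the model code):

    import torch

    # Stand-in for an attention output whose memory layout the caller does not control.
    x = torch.randn(4, 8, 16).transpose(0, 1)   # shape (8, 4, 16), non-contiguous

    try:
        x.view(-1, 4 * 16)                      # view() requires a compatible memory layout
    except RuntimeError as err:
        print("view() failed:", err)

    y = x.reshape(-1, 4 * 16)                   # reshape() falls back to a copy when needed
    print(y.shape)                              # torch.Size([8, 64])
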

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

Lines changed: 3 additions & 8 deletions
@@ -135,20 +135,16 @@ def forward(
 
         query = qkv[:, 0]
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 qkv[:, 1],
                 qkv[:, 2],
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -159,11 +155,10 @@ def forward(
             layer_past[layer_past_present_indices] = qkv[:, 1:]
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 layer_past[:, 0],
                 layer_past[:, 1],
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -172,7 +167,7 @@ def forward(
                 False,
             )
 
-        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.dense(attn_output.reshape(-1, self.num_heads * self.head_size))
 
 
 class FlashMLP(nn.Module):

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

Lines changed: 5 additions & 15 deletions
@@ -175,20 +175,16 @@ def forward(
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(kv[:, 0], cos, sin)
 
-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -199,11 +195,10 @@ def forward(
             layer_past[layer_past_present_indices] = kv
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 layer_past[:, 0],
                 layer_past[:, 1],
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -212,7 +207,7 @@ def forward(
                 False,
             )
 
-        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.dense(attn_output.reshape(-1, self.num_heads * self.head_size))
 
 
 class FlashRWLargeAttention(torch.nn.Module):
@@ -286,20 +281,16 @@ def forward(
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(kv[:, :, 0], cos, sin)
 
-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -310,11 +301,10 @@ def forward(
             layer_past[layer_past_present_indices] = kv
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 layer_past[:, :, 0],
                 layer_past[:, :, 1],
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
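
The same call-site change is applied to both Falcon attention classes; the only difference visible here is how K and V are pulled out of the packed kv tensor (dim=1 for FlashRWAttention, dim=2 for FlashRWLargeAttention, matching how each class lays out kv). A small, hedged illustration of that torch.select pattern with made-up shapes (the real head and group counts come from the model config):

    import torch

    num_tokens, num_kv, head_size = 6, 4, 64

    # FlashRWAttention-style packing: (tokens, 2, kv_heads, head_size)
    kv = torch.randn(num_tokens, 2, num_kv, head_size)
    k = torch.select(kv, dim=1, index=0)             # (tokens, kv_heads, head_size)
    v = torch.select(kv, dim=1, index=1)

    # FlashRWLargeAttention-style packing: (tokens, groups, 2, head_size)
    kv_grouped = torch.randn(num_tokens, num_kv, 2, head_size)
    k_g = torch.select(kv_grouped, dim=2, index=0)   # (tokens, groups, head_size)
    v_g = torch.select(kv_grouped, dim=2, index=1)

    print(k.shape, k_g.shape)
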

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

Lines changed: 3 additions & 8 deletions
@@ -243,20 +243,16 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)
 
-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = key_value
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -267,11 +263,10 @@ def forward(
             layer_past[layer_past_present_indices] = key_value
 
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 layer_past[:, 0],
                 layer_past[:, 1],
-                attn_output,
                 cu_seqlens,
                 max_s,
                 self.softmax_scale,
@@ -280,7 +275,7 @@ def forward(
                 False,
             )
 
-        return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.c_proj(attn_output.reshape(-1, self.num_heads * self.head_size))
 
 
 class MLP(nn.Module):

server/text_generation_server/utils/flash_attn.py

Lines changed: 6 additions & 4 deletions
@@ -44,7 +44,6 @@ def attention(
     q,
     k,
     v,
-    out,
     cu_seqlens,
     max_s,
     softmax_scale,
@@ -61,10 +60,11 @@ def attention(
             q,
             k,
             v,
-            out,
+            None,
             cu_seqlens_q,
             cu_seqlens,
             None,
+            None,
             max_s_q,
             max_s,
             0.0,
@@ -75,7 +75,7 @@ def attention(
             -1,
             False,
             None,
-        )
+        )[0]
 
     if HAS_FLASH_ATTN:
         # Flash attention v1 requires q, k and v to have the same number of heads
@@ -104,7 +104,8 @@ def attention(
             .reshape(original_shape[0], -1, original_shape[2])
         )
 
-    return flash_attn_cuda.fwd(
+    out = torch.empty_like(q)
+    flash_attn_cuda.fwd(
         q,
         k,
         v,
@@ -121,5 +122,6 @@ def attention(
         0,
         None,
     )
+    return out
 
 raise NotImplementedError("flash attention is not installed")
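
With this change the wrapper no longer takes an out tensor at all: on the flash-attention v2 path it passes None for the output buffer, threads an additional None through the expanded varlen_fwd argument list, and returns the first element of the tuple varlen_fwd gives back; on the v1 path it allocates the buffer internally and returns it. That "allocate and return" contract is the same one exposed by flash-attn's public Python API. The sketch below is a hedged usage example against that public API (flash_attn_varlen_func) rather than the private varlen_fwd binding; it assumes a CUDA device with flash-attn 2.x installed, and the shapes are illustrative only:

    import torch
    from flash_attn import flash_attn_varlen_func

    total_tokens, num_heads, head_size = 32, 8, 64
    q = torch.randn(total_tokens, num_heads, head_size, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Two packed sequences of 16 tokens each (no padding), as in the varlen layout.
    cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device="cuda")

    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=16,
        max_seqlen_k=16,
        softmax_scale=head_size ** -0.5,
        causal=True,
    )
    print(out.shape)   # (total_tokens, num_heads, head_size) -- no preallocated output
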
