5 | 5 | utils::{destruct, distinct, offset_ptr, strides}, |
6 | 6 | }; |
7 | 7 | use cuda::{CaptureStream, GraphExec, Module, Stream, VirByte}; |
8 | | -use flash_attn::attention::{FlashAttnCfg, KVPage, KernelReq, Strides2D}; |
| 8 | +use flash_attn::attention::{AttnType, FlashAttnCfg, KVPage, KernelReq, Strides2D}; |
9 | 9 | use ggus::ggml_quants::f16; |
10 | 10 | use nn::{Arg, Named, Tensor, digit_layout::types}; |
11 | 11 | use regex::Regex; |
12 | | -use std::{fmt, sync::LazyLock}; |
| 12 | +use std::{fmt, ptr::null, sync::LazyLock}; |
13 | 13 |
14 | 14 | pub(super) enum Step<'ctx> { |
15 | 15 | Graph(GraphExec<'ctx>, Box<[Tensor<*const VirByte, 2>]>), |
@@ -144,31 +144,28 @@ impl<'ctx> Handle<'ctx> { |
144 | 144 | reqs: &[Req<Tensor<*const VirByte, 2>>], |
145 | 145 | stream: &Stream, |
146 | 146 | ) { |
| 147 | + use ::flash_attn::attention::cuda::code as flash_attn_code; |
147 | 148 | let Attention { q, k, v, o, .. } = attn; |
148 | 149 | let dt = distinct(&[q.dt(), k.dt(), v.dt(), o.dt()]).unwrap(); |
149 | 150 | // compile |
150 | 151 | let key = [ModuleKey::Text("flash-attn"), ModuleKey::Type(dt)].into_iter(); |
151 | | - let [t_compute, t_data] = match dt { |
152 | | - types::F16 => ["float", "half"], |
153 | | - _ => todo!(), |
154 | | - }; |
155 | | - let module = self.compile(key.collect(), || { |
156 | | - ::flash_attn::attention::cuda::code(t_compute, t_data) |
157 | | - }); |
158 | 152 | match dt { |
159 | | - types::F16 => launch_attn_typed::<f16>(attn, reqs, module, stream), |
| 153 | + types::F16 => { |
| 154 | + let module = self.compile(key.collect(), || flash_attn_code::<f16>()); |
| 155 | + launch_attn_typed::<f16>(attn, reqs, module, stream) |
| 156 | + } |
160 | 157 | _ => todo!(), |
161 | 158 | } |
162 | 159 | } |
163 | 160 | } |
164 | 161 |
165 | | -fn launch_attn_typed<T: Copy>( |
| 162 | +fn launch_attn_typed<T: ::flash_attn::attention::cuda::NVDT>( |
166 | 163 | attn: &Attention, |
167 | 164 | reqs: &[Req<Tensor<*const VirByte, 2>>], |
168 | 165 | module: &Module, |
169 | 166 | stream: &Stream, |
170 | 167 | ) { |
171 | | - const TILE_SEQ: usize = 32; |
| 168 | + const TILE_SEQ: usize = 8; |
172 | 169 | const TILE_CTX: usize = 32; |
173 | 170 |
174 | 171 | let Attention { iblk, q, k, v, o } = attn; |
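Note on the hunk above: compilation moves inside the per-type match arm, and the string pair ("float", "half") that used to parameterize `code(t_compute, t_data)` is replaced by the monomorphized `flash_attn_code::<f16>()`, with `launch_attn_typed` bounded by `NVDT` instead of `Copy`. The crate's actual trait definition is not shown in this diff; the following is only a hedged sketch of how such a bound could carry the CUDA type names at the type level (every name here is hypothetical, and `F16` is a stand-in for `ggus::ggml_quants::f16`):

// Hypothetical sketch, not the flash_attn crate's real API.
trait NvdtLike: Copy {
    const COMPUTE: &'static str; // accumulation type in the generated CUDA source
    const DATA: &'static str;    // storage type in the generated CUDA source
}

// Stand-in for the f16 type used in the diff (ggus::ggml_quants::f16).
#[derive(Clone, Copy)]
struct F16;

impl NvdtLike for F16 {
    const COMPUTE: &'static str = "float";
    const DATA: &'static str = "half";
}

fn code<T: NvdtLike>() -> String {
    // splice the type names into the kernel template, as the old
    // code(t_compute, t_data) call did with explicit string arguments
    format!("/* flash-attn kernel for {} / {} */", T::COMPUTE, T::DATA)
}

The practical effect is that the Rust type parameter alone selects the generated CUDA source, so the per-type match arm only names the type once.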
@@ -230,25 +227,10 @@ fn launch_attn_typed<T: Copy>( |
230 | 227 | }) |
231 | 228 | }) |
232 | 229 | .collect::<Box<_>>(); |
233 | | - // generate the mask |
234 | | - let masks = reqs |
235 | | - .iter() |
236 | | - .map(|req| { |
237 | | - let Req { pos, seq: n, .. } = req; |
238 | | - let s = pos + n; |
239 | | - let s_ceil = s.div_ceil(TILE_CTX) * TILE_CTX; |
240 | | - // attention mask |
241 | | - let mask = (0..n * s_ceil) |
242 | | - .map(|i| i % s_ceil <= s - n + i / s_ceil) |
243 | | - .collect::<Box<_>>(); |
244 | | - stream.from_host(&mask) |
245 | | - }) |
246 | | - .collect::<Box<_>>(); |
247 | 230 | // generate a block for each head of each request |
248 | 231 | let reqs_ = reqs |
249 | 232 | .iter() |
250 | | - .zip(&masks) |
251 | | - .scan((0, 0), |(seq, page), (req, mask)| { |
| 233 | + .scan((0, 0), |(seq, page), req| { |
252 | 234 | let &Req { |
253 | 235 | ref cache, |
254 | 236 | pos, |
@@ -288,9 +270,10 @@ fn launch_attn_typed<T: Copy>( |
288 | 270 | kv_strides, |
289 | 271 | o: offset_ptr(&o).cast_mut().cast(), |
290 | 272 | o_strides, |
291 | | - mask: mask.as_ptr().cast(), |
292 | 273 | n, |
293 | 274 | s: pos + n, |
| 275 | + ty: AttnType::Causal, |
| 276 | + mask: null(), |
294 | 277 | }) |
295 | 278 | }) |
296 | 279 | .collect::<Box<_>>(); |
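The hunks above drop the host-side mask buffers: instead of building and uploading a boolean mask per request, the kernel request now carries `ty: AttnType::Causal` with a null `mask` pointer. For reference, a minimal sketch of the pattern the removed code used to materialize (the helper name `causal_mask` is hypothetical; `n` is the number of new tokens, `pos` the number of cached tokens, and rows are padded to a multiple of `TILE_CTX` exactly as the old `s_ceil` computation did):

fn causal_mask(n: usize, pos: usize, tile_ctx: usize) -> Box<[bool]> {
    let s = pos + n;                              // total context length for this request
    let s_ceil = s.div_ceil(tile_ctx) * tile_ctx; // pad each row to whole tiles
    (0..n * s_ceil)
        .map(|i| {
            let (row, col) = (i / s_ceil, i % s_ceil);
            // query token `row` may attend to all cached tokens and to the
            // new tokens up to and including itself
            col <= s - n + row
        })
        .collect()
}

With `AttnType::Causal`, the kernel can derive the same predicate from `n` and `s` on the fly, which removes one host allocation and one host-to-device copy (`stream.from_host`) per request.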