Commit 6063373

feat(llama.cu): remove operators-rs, integrate flash attn

Signed-off-by: YdrMaster <[email protected]>
1 parent 2c1d870

31 files changed (+241, -196 lines)

Cargo.lock

Lines changed: 20 additions & 56 deletions (generated file; diff not rendered)

llama.cu/Cargo.toml

Lines changed: 4 additions & 3 deletions

```diff
@@ -4,9 +4,10 @@ version = "0.0.0"
 edition.workspace = true
 
 [dependencies]
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "88c58bd", default-features = false, features = [
-    "nvidia-gpu",
-] }
+cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
+cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
+nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
+flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "57176f5" }
 nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "171c5b0" }
 ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
 tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }
```
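
This dependency swap is the core of the commit: the operators-rs meta-crate is replaced by the individual cuda-driver crates (cuda, cublas, nccl), all pinned to the same rev, plus a standalone flash-attn crate. The effect on the source files below is largely mechanical; a minimal illustration (paths taken from the hunks in this commit):

```rust
// Before: cuda-driver types were reached through the operators-rs re-export.
// use operators::cuda::{ContextResource, CurrentCtx, Device};

// After: the same types come straight from the pinned `cuda` crate.
use cuda::{ContextResource, CurrentCtx, Device};
```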

llama.cu/src/exec/engine.rs

Lines changed: 7 additions & 23 deletions

```diff
@@ -9,12 +9,8 @@ use crate::{
     handle::Handle,
     op::{FastEmbedding, random_sample::KVPair},
 };
+use cuda::{ContextResource, CurrentCtx, Device, Event, HostMem};
 use nn::{Distribution, LLaMA, Tensor};
-use operators::{
-    Operator,
-    attention_kv_cached::cuda::Operator as Attn,
-    cuda::{ContextResource, CurrentCtx, Device, Event, Gpu, HostMem},
-};
 use std::{
     ffi::c_int,
     iter::zip,
@@ -30,7 +26,7 @@ use std::{
 use tokeneer::utok;
 
 #[cfg(nccl)]
-use operators::nccl::{Communicator, CommunicatorGroup};
+use nccl::{Communicator, CommunicatorGroup};
 
 type Stub = SessionStub<CacheParts>;
 
@@ -222,16 +218,13 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
         } = self;
 
         dev.set_mempool_threshold(u64::MAX);
-        let gpu = Gpu::new(dev.retain_primary(), Default::default());
-        let attn = Attn::new(&gpu);
-        gpu.apply(|ctx| {
+        dev.retain_primary().apply(|ctx| {
             let mut handle = handle(ctx);
             let mut models = ModelGroup::new(
                 llama,
                 dist,
                 progress,
                 config,
-                attn,
                 &mut handle,
                 barrier.as_deref(),
             );
@@ -373,21 +366,12 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
             ..
         } = self;
 
-        dev.set_mempool_threshold(u64::MAX);
-        let gpu = Gpu::new(dev.retain_primary(), Default::default());
-        let attn = Attn::new(&gpu);
         let barrier = barrier.unwrap();
-        gpu.apply(|ctx| {
+        dev.set_mempool_threshold(u64::MAX);
+        dev.retain_primary().apply(|ctx| {
             let mut handle = Handle::with_comm(ctx, comm);
-            let mut models = ModelGroup::new(
-                llama,
-                dist,
-                progress,
-                config,
-                attn,
-                &mut handle,
-                Some(&barrier),
-            );
+            let mut models =
+                ModelGroup::new(llama, dist, progress, config, &mut handle, Some(&barrier));
 
             let stream = ctx.stream();
             loop {
```
llama.cu/src/exec/group.rs

Lines changed: 5 additions & 13 deletions

```diff
@@ -1,12 +1,9 @@
 use super::{CacheParts, Progress, model::ModelExec, upos};
 use crate::{batch::Req, handle::Handle, load::load_weight, memory::MemPages};
+use cuda::{DevByte, DevMem, Stream, VirByte};
 use nn::{
     Distribution, Graph, GraphBuilder, LLaMA, NNGraph, Tensor, TensorMeta, digit_layout::types, op,
 };
-use operators::{
-    attention_kv_cached::cuda::Operator as Attn,
-    cuda::{DevByte, DevMem, Stream, VirByte},
-};
 use std::{
     collections::BTreeMap,
     num::{NonZero, NonZeroUsize},
@@ -16,7 +13,6 @@ use tokeneer::utok;
 
 pub(crate) struct ModelGroup<'ctx> {
     internal: Internal<'ctx>,
-    attn: Attn,
     pages: MemPages,
     _weight: DevMem<'ctx>,
 }
@@ -36,7 +32,6 @@ impl<'ctx> ModelGroup<'ctx> {
 
     config: ModelGroupConfig<T>,
 
-    attn: Attn,
     handle: &mut Handle<'ctx>,
     barrier: Option<&Barrier>,
 ) -> Self {
@@ -82,7 +77,6 @@ impl<'ctx> ModelGroup<'ctx> {
         let models_with_one_dyn = Internal::new(graph, static_models, dyn_cache_size);
         Self {
             internal: models_with_one_dyn,
-            attn,
             pages,
             _weight,
         }
@@ -125,10 +119,7 @@ impl<'ctx> ModelGroup<'ctx> {
         stream: &Stream<'ctx>,
     ) -> Tensor<*const VirByte, 2> {
         let Self {
-            internal,
-            attn,
-            pages,
-            ..
+            internal, pages, ..
         } = self;
 
         let mut reqs = reqs
@@ -142,7 +133,8 @@ impl<'ctx> ModelGroup<'ctx> {
         let reqs = reqs
             .iter_mut()
             .map(|req| {
-                req.cache.update(req.pos + req.seq, pages);
+                req.cache
+                    .update((req.pos + req.seq).div_ceil(32) * 32, pages);
                 Req {
                     cache: req.cache.as_tensor(),
                     pos: req.pos,
@@ -154,7 +146,7 @@ impl<'ctx> ModelGroup<'ctx> {
         internal
             .get_mut(&key)
             .unwrap()
-            .launch(attn, handle, &reqs, stream)
+            .launch(handle, &reqs, stream)
     }
 }
```
llama.cu/src/exec/kv_cache.rs

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,6 +1,6 @@
-use crate::memory::MemPages;
+use crate::memory::MemPages;
+use cuda::{VirByte, VirMem};
 use nn::Tensor;
-use operators::cuda::{VirByte, VirMem};
 
 pub(crate) struct KVCache {
     /// cache tensor backed by virtual addresses
```

llama.cu/src/exec/mod.rs

Lines changed: 1 addition & 1 deletion

```diff
@@ -12,7 +12,7 @@ use crate::{
     batch::{Session as Session_, SessionId},
     op::random_sample::KVPair,
 };
-use operators::cuda::{ContextSpore, CurrentCtx, DevMemSpore, EventSpore, Stream};
+use cuda::{ContextSpore, CurrentCtx, DevMemSpore, EventSpore, Stream};
 use std::collections::BTreeMap;
 use tokeneer::utok;
 
```

llama.cu/src/exec/model.rs

Lines changed: 2 additions & 6 deletions

```diff
@@ -6,12 +6,9 @@ use crate::{
     utils::{self, destruct},
 };
 use bytesize::ByteSize;
+use cuda::{DevByte, Stream, VirByte, VirMem};
 use log::trace;
 use nn::{NNGraph, Tensor};
-use operators::{
-    attention_kv_cached::cuda::Operator as Attn,
-    cuda::{DevByte, Stream, VirByte, VirMem},
-};
 use std::time::Instant;
 
 pub(super) struct ModelExec<'ctx> {
@@ -100,7 +97,6 @@ impl ModelExec<'_> {
 
     pub fn launch(
         &mut self,
-        attn: &Attn,
         handle: &mut Handle,
         reqs: &[Req<Tensor<*const VirByte, 2>>],
         stream: &Stream,
@@ -117,7 +113,7 @@ impl ModelExec<'_> {
                     std::process::exit(0);
                 }
             }
-            Step::Attention(box_) => handle.launch_attn(attn, box_, reqs, stream),
+            Step::Attention(box_) => handle.launch_attn(box_, reqs, stream),
            Step::Exec(exec) => handle.launch_nn_exec(exec, stream),
        }
    }
```
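After this change the attention operator is no longer threaded through `launch`; the flash-attn call is owned by `Handle` behind `launch_attn`. A stand-in mock of the narrowed dispatch (the real `Step`, `Handle`, and payload types are not shown in this diff, so placeholders are used):

```rust
// Placeholder types; only the shape of the dispatch mirrors the hunk above.
enum Step {
    Attention(u32), // stands in for the attention node `box_`
    Exec(u32),      // stands in for a generic graph op
}

struct Handle;

impl Handle {
    fn launch_attn(&mut self, _attn: u32) { /* flash-attn path, internal to Handle */ }
    fn launch_nn_exec(&mut self, _exec: u32) { /* every other graph op */ }
}

fn run(handle: &mut Handle, steps: &[Step]) {
    for step in steps {
        match step {
            // no more `attn` argument: the operator state lives inside `Handle`
            Step::Attention(a) => handle.launch_attn(*a),
            Step::Exec(e) => handle.launch_nn_exec(*e),
        }
    }
}
```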

llama.cu/src/exec/output_head.rs

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,11 +1,11 @@
-use crate::{
+use crate::{
     handle::Handle,
     load::WeightLoader,
     op::{self, Operator as _},
     utils::dims,
 };
+use cuda::{CurrentCtx, DevMem, Stream, VirByte};
 use nn::{Arg, Linear, NormType, Normalization, Tensor, digit_layout::types};
-use operators::cuda::{CurrentCtx, DevMem, Stream, VirByte};
 use tokeneer::utok;
 
 pub(super) struct OutputHead<'ctx> {
```

llama.cu/src/exec/sample_manager.rs

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,11 +1,11 @@
-use crate::{
+use crate::{
     SessionId,
     batch::SampleInfo,
     op::random_sample::{KV_PAIR, KVPair, LogitsModifier, RandomSample},
     utils::dims,
 };
+use cuda::{CurrentCtx, DevByte, DevMem, Stream};
 use nn::Tensor;
-use operators::cuda::{CurrentCtx, DevByte, DevMem, Stream};
 use std::{collections::BTreeMap, ptr::null};
 use tokeneer::utok;
 
```
