
Commit 816101f

perf(llama): optimize distributed sharding and parameter loading
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 8729b3d commit 816101f

9 files changed, +364 -450 lines


Cargo.toml
Lines changed: 3 additions & 3 deletions

@@ -7,8 +7,8 @@ members = [
     "models/llama/common",
     "models/llama/common-cpu",
-    "models/llama/opencl",
-    "models/llama/infini",
+    # "models/llama/opencl",
+    # "models/llama/infini",
     "models/llama/cuda",

     "models/clip/common",
@@ -34,7 +34,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"

-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "df027a4", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "fd8f972", default-features = false }

 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }

common/src/lib.rs
Lines changed: 45 additions & 1 deletion

@@ -1,4 +1,9 @@
-use std::{borrow::Borrow, collections::HashMap, hash::Hash, ops::Deref};
+use std::{
+    borrow::Borrow,
+    collections::HashMap,
+    hash::Hash,
+    ops::{Deref, Range},
+};

 pub enum Contiguous<'a, T> {
     Borrowed(&'a [u8]),
@@ -52,3 +57,42 @@ impl<K: Eq + Hash, V> Slab<K, V> {
         self.0.entry(key).or_default().push(value);
     }
 }
+
+#[derive(Clone, Copy, Debug)]
+pub struct Distribution {
+    pub start: usize,
+    pub len: usize,
+    pub total: usize,
+}
+
+impl Distribution {
+    pub const MONO: Self = Self {
+        start: 0,
+        len: 1,
+        total: 1,
+    };
+}
+
+pub struct WeightMemCalculator {
+    align: usize,
+    size: usize,
+}
+
+impl WeightMemCalculator {
+    #[inline]
+    pub const fn new(align: usize) -> Self {
+        Self { align, size: 0 }
+    }
+
+    #[inline]
+    pub const fn size(&self) -> usize {
+        self.size
+    }
+
+    #[inline]
+    pub fn push(&mut self, size: usize) -> Range<usize> {
+        let start = self.size.div_ceil(self.align) * self.align;
+        self.size = start + size;
+        start..self.size
+    }
+}
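
The two helpers added here carry the rest of the commit: `Distribution` names one worker's slice of a tensor-parallel split (parts `start..start + len` out of `total`), and `WeightMemCalculator` hands out alignment-padded byte ranges while accumulating the total allocation size. A minimal usage sketch with made-up buffer sizes (only the two types and `Distribution::MONO` come from the commit):

use common::{Distribution, WeightMemCalculator};

fn main() {
    // The second of four workers, owning one part of the split.
    // A single-device run would use the predefined Distribution::MONO instead.
    let dist = Distribution { start: 1, len: 1, total: 4 };
    assert!(dist.start + dist.len <= dist.total);

    // Lay out three hypothetical weight buffers with 256-byte alignment;
    // the sizes are illustrative, not taken from any real model.
    let mut calc = WeightMemCalculator::new(256);
    let qkv = calc.push(6 * 1024 * 1024 / dist.total);
    let o = calc.push(2 * 1024 * 1024 / dist.total);
    let norm = calc.push(8 * 1024);

    // Every range starts on an alignment boundary; `size()` is the total
    // number of bytes to reserve for this worker's weights.
    for r in [&qkv, &o, &norm] {
        assert_eq!(r.start % 256, 0);
    }
    println!("qkv {qkv:?}, o {o:?}, norm {norm:?}, total {}", calc.size());
}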

models/llama/common-cpu/src/infer.rs
Lines changed: 9 additions & 6 deletions

@@ -1,4 +1,5 @@
 use crate::{Operators, RandomSample, Weights};
+use common::Distribution;
 use gguf::GGufModel;
 use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
 use operators::{
@@ -59,16 +60,18 @@ fn test_infer() {
     let _workers = zip(lens, seeds)
         .enumerate()
         .scan(0, |start, (id, (len, seed))| {
-            let range = *start..*start + len;
-            *start = range.end;
-
-            let mut meta = model.meta.clone();
-            meta.distribute(range.clone(), count);
+            let dist = Distribution {
+                start: *start,
+                len,
+                total: count,
+            };
+            *start += len;

+            let meta = model.meta.distribute(dist);
             let model = &model;
             Some(s.spawn(move || {
                 let WorkerSeed { node, tasks } = seed;
-                let weights = Weights::new(model, range, count);
+                let weights = Weights::new(model, dist);
                 let mut worker = Worker::new(id, &node, meta.clone(), weights);
                 let mut cache = meta.kv_cache(meta.nctx).map(Blob::new);
                 let sin_cos = <Operators as llama::Operators>::build_sin_cos(
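
The worker setup now passes a single `Distribution` value instead of a `Range` plus a part count, so the shard geometry travels as one `Copy` struct. The same `scan` pattern can be written as a standalone helper; this hypothetical sketch (not part of the commit) takes the total to be the sum of the per-worker lengths, whereas the test above passes its own `count`:

use common::Distribution;

/// Turn per-worker shard lengths into consecutive, non-overlapping distributions.
fn distributions(lens: &[usize]) -> Vec<Distribution> {
    let total = lens.iter().sum();
    lens.iter()
        .scan(0, |start, &len| {
            let dist = Distribution { start: *start, len, total };
            *start += len;
            Some(dist)
        })
        .collect()
}

fn main() {
    // An uneven split of 8 parts over 3 workers.
    let dists = distributions(&[3, 3, 2]);
    assert_eq!(dists[1].start, 3);
    let last = dists.last().unwrap();
    assert_eq!(last.start + last.len, last.total); // the shards tile the whole split
}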

models/llama/common-cpu/src/lib.rs
Lines changed: 22 additions & 20 deletions

@@ -1,7 +1,7 @@
-use common::Contiguous;
+use common::{Contiguous, Distribution};
 use llama::{
     ext::ggml_quants::{self, digit_layout::DigitLayout, f16, DataBlock, QuantExt},
-    BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor,
+    LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor,
     TensorUsage::Computation,
     WeightLoader,
 };
@@ -16,7 +16,7 @@ use std::{
     cell::{Ref, RefCell},
     marker::PhantomData,
     mem::size_of,
-    ops::{Deref, Range, RangeBounds},
+    ops::{Deref, Range},
     ptr::copy_nonoverlapping,
     slice::{from_raw_parts, from_raw_parts_mut},
 };
@@ -41,7 +41,7 @@ pub struct Weights<'w> {

 pub struct WeightCache {
     cache: Blob,
-    cached_weight: BlkWeight,
+    cached_weight: LlamaBlkWeight,
     cached_weight_iblk: usize,
 }

@@ -85,11 +85,7 @@ where
 }

 impl<'w> Weights<'w> {
-    pub fn new(
-        model: &'w LlamaStorage<&'w [u8]>,
-        range: impl RangeBounds<usize> + Clone,
-        count: usize,
-    ) -> Self {
+    pub fn new(model: &'w LlamaStorage<&'w [u8]>, dist: Distribution) -> Self {
         let LlamaStorage {
             meta,
             output_norm,
@@ -100,11 +96,17 @@ impl<'w> Weights<'w> {

         let blks = blocks
             .iter()
-            .map(|blk| blk.distribute(meta, range.clone(), count, Blob::new))
+            .map(|blk| {
+                blk.into_vec()
+                    .into_iter()
+                    .map(|(which, data)| {
+                        (which, meta.distribute_data(which, data, dist, Blob::new))
+                    })
+                    .collect::<LlamaBlkStorage<_>>()
+            })
             .collect::<Box<_>>();

-        let mut meta = meta.clone();
-        meta.distribute(range.clone(), count);
+        let meta = meta.distribute(dist);
         let size_qkv = meta.attn_qkv(Computation).take();
         let size_o = meta.attn_o(Computation).take();
         let size_gate_up = meta.ffn_gate_up(Computation).take();
@@ -113,7 +115,7 @@ impl<'w> Weights<'w> {
         let weight_cache = if meta.dt_embd == meta.dt_linear {
             RefCell::new(WeightCache {
                 cache: Blob::new(0),
-                cached_weight: BlkWeight::AttnQKV,
+                cached_weight: LlamaBlkWeight::AttnQKV,
                 cached_weight_iblk: 0,
             })
         } else {
@@ -131,7 +133,7 @@ impl<'w> Weights<'w> {

             RefCell::new(WeightCache {
                 cache,
-                cached_weight: BlkWeight::AttnQKV,
+                cached_weight: LlamaBlkWeight::AttnQKV,
                 cached_weight_iblk: 0,
             })
         };
@@ -207,7 +209,7 @@ impl WeightLoader for Weights<'_> {
     #[inline]
     fn load_blk(
         &self,
-        which: BlkWeight,
+        which: LlamaBlkWeight,
         iblk: usize,
         _queue: &QueueOf<Self::Hardware>,
     ) -> Self::Weight<'_> {
@@ -233,10 +235,10 @@ impl WeightLoader for Weights<'_> {
             ffn_down,
         } = &blks[iblk];

-        use BlkWeight::{
+        use Dequant::{Borrowed, Cached};
+        use LlamaBlkWeight::{
             AttnNorm, AttnO, AttnQKV, AttnQKVBias, FfnDown, FfnGateInp, FfnGateUp, FfnNorm,
         };
-        use Dequant::{Borrowed, Cached};

         #[rustfmt::skip]
         match which {
@@ -301,7 +303,7 @@ impl WeightLoader for Weights<'_> {

     fn load_moe<'a>(
         &'a self,
-        which: BlkWeight,
+        which: LlamaBlkWeight,
         iblk: usize,
         iexp: usize,
         _queue: &'a QueueOf<Self::Hardware>,
@@ -315,8 +317,8 @@ impl WeightLoader for Weights<'_> {
         } = self;
         assert_eq!(dt_embd, dt_mat);
         let w = match which {
-            BlkWeight::FfnGateUp => &*blks[iblk].ffn_gate_up,
-            BlkWeight::FfnDown => &*blks[iblk].ffn_down,
+            LlamaBlkWeight::FfnGateUp => &*blks[iblk].ffn_gate_up,
+            LlamaBlkWeight::FfnDown => &*blks[iblk].ffn_down,
             _ => unreachable!(),
         };
         let one = w.len() / nexp;
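
Beyond the renames, the `WeightCache { cache, cached_weight, cached_weight_iblk }` fields touched above describe a single-slot dequantization cache: one buffer holds the dequantized bytes of the most recently requested (weight kind, block index) pair and is rebuilt only on a miss. A rough, self-contained sketch of that idea, with simplified stand-in types rather than the commit's own code:

// Simplified stand-in for LlamaBlkWeight, limited to a few kinds.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum BlkKind { AttnQkv, AttnO, FfnGateUp, FfnDown }

// Single-slot cache: at most one dequantized buffer is kept alive at a time.
struct WeightCache {
    cache: Vec<u8>,
    cached_weight: BlkKind,
    cached_weight_iblk: usize,
}

impl WeightCache {
    fn get_or_dequant(
        &mut self,
        which: BlkKind,
        iblk: usize,
        dequant: impl FnOnce() -> Vec<u8>,
    ) -> &[u8] {
        let hit = !self.cache.is_empty()
            && self.cached_weight == which
            && self.cached_weight_iblk == iblk;
        if !hit {
            // Miss: dequantize into the slot and remember which weight it holds.
            self.cache = dequant();
            self.cached_weight = which;
            self.cached_weight_iblk = iblk;
        }
        &self.cache
    }
}

fn main() {
    let mut cache = WeightCache {
        cache: Vec::new(),
        cached_weight: BlkKind::AttnQkv,
        cached_weight_iblk: 0,
    };
    let a = cache.get_or_dequant(BlkKind::AttnQkv, 0, || vec![1; 16]).len();
    let b = cache.get_or_dequant(BlkKind::AttnQkv, 0, || unreachable!()).len(); // hit, no re-dequant
    assert_eq!((a, b), (16, 16));
}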

models/llama/common/src/compute.rs
Lines changed: 15 additions & 23 deletions

@@ -1,4 +1,4 @@
-use super::{args::Args, LlamaMeta};
+use super::{args::Args, LlamaBlkWeight, LlamaMeta};
 use gguf::ggml_quants::{
     digit_layout::{types as ty, DigitLayout},
     f16,
@@ -53,18 +53,6 @@ pub trait Operators {
     }
 }

-#[derive(Clone, Copy, PartialEq, Eq, Debug)]
-pub enum BlkWeight {
-    AttnNorm,
-    AttnQKV,
-    AttnQKVBias,
-    AttnO,
-    FfnNorm,
-    FfnGateInp,
-    FfnGateUp,
-    FfnDown,
-}
-
 pub trait WeightLoader {
     type Hardware: Hardware;
     type Weight<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
@@ -73,14 +61,14 @@ pub trait WeightLoader {

     fn load_blk<'a>(
         &'a self,
-        which: BlkWeight,
+        which: LlamaBlkWeight,
         iblk: usize,
         queue: &'a QueueOf<Self::Hardware>,
     ) -> Self::Weight<'a>;

     fn load_moe<'a>(
         &'a self,
-        which: BlkWeight,
+        which: LlamaBlkWeight,
         iblk: usize,
         iexp: usize,
         queue: &'a QueueOf<Self::Hardware>,
@@ -638,7 +626,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iblk: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue);
+        let w = self.weights.load_blk(LlamaBlkWeight::AttnNorm, iblk, queue);
         self.norm.clone().map(|_| w)
     }

@@ -648,7 +636,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iblk: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::AttnQKV, iblk, queue);
+        let w = self.weights.load_blk(LlamaBlkWeight::AttnQKV, iblk, queue);
         self.attn_qkv.clone().map(|_| w)
     }

@@ -658,7 +646,9 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iblk: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::AttnQKVBias, iblk, queue);
+        let w = self
+            .weights
+            .load_blk(LlamaBlkWeight::AttnQKVBias, iblk, queue);
         self.attn_qkv_bias.clone().map(|_| w)
     }

@@ -668,7 +658,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iblk: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::AttnO, iblk, queue);
+        let w = self.weights.load_blk(LlamaBlkWeight::AttnO, iblk, queue);
         self.attn_o.clone().map(|_| w)
     }

@@ -678,7 +668,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iblk: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue);
+        let w = self.weights.load_blk(LlamaBlkWeight::FfnNorm, iblk, queue);
        self.norm.clone().map(|_| w)
     }

@@ -688,7 +678,9 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iblk: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::FfnGateInp, iblk, queue);
+        let w = self
+            .weights
+            .load_blk(LlamaBlkWeight::FfnGateInp, iblk, queue);
         self.ffn_gate_inp.clone().map(|_| w)
     }

@@ -699,7 +691,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iexp: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        const WHICH: BlkWeight = BlkWeight::FfnGateUp;
+        const WHICH: LlamaBlkWeight = LlamaBlkWeight::FfnGateUp;
         let w = if self.is_moe {
             self.weights.load_moe(WHICH, iblk, iexp, queue)
         } else {
@@ -715,7 +707,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         iexp: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        const WHICH: BlkWeight = BlkWeight::FfnDown;
+        const WHICH: LlamaBlkWeight = LlamaBlkWeight::FfnDown;
         let w = if self.is_moe {
             self.weights.load_moe(WHICH, iblk, iexp, queue)
         } else {