Skip to content

Commit c7ef588

Browse files
committed
todo: 创建 llama-nv (create the llama-nv crate)
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent cd61335 commit c7ef588

File tree

5 files changed

+111
-17
lines changed

5 files changed

+111
-17
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ members = [
44
"tensor",
55
"models/llama/common",
66
"models/llama/common-cpu",
7+
"models/llama/nvidia-gpu",
78
"test-utils",
89
]
910
resolver = "2"
@@ -14,6 +15,6 @@ tensor.path = "tensor"
1415
causal-lm.path = "causal-lm"
1516
test-utils.path = "test-utils"
1617

17-
ggus = { git = "https://github.com/YdrMaster/gguf", rev = "e64d758" }
18+
ggus = { git = "https://github.com/YdrMaster/gguf", rev = "c676bcc" }
1819
ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "5c6b969" }
1920
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "64419f0", default-features = false }

models/llama/common-cpu/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,15 @@ struct Weights {
142142

143143
impl WeightLoader for Weights {
144144
type Hardware = Cpu;
145-
type Memory = &'static [u8];
145+
type Memory<'s> = &'s [u8];
146146

147147
#[inline]
148148
fn load_blk(
149149
&self,
150150
which: BlkWeight,
151151
iblk: usize,
152152
_queue: &QueueOf<Self::Hardware>,
153-
) -> Self::Memory {
153+
) -> Self::Memory<'_> {
154154
let blk = &self.blks[iblk];
155155
match which {
156156
BlkWeight::AttnNorm => blk.attn_norm,
@@ -163,12 +163,12 @@ impl WeightLoader for Weights {
163163
}
164164

165165
#[inline]
166-
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Memory {
166+
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Memory<'_> {
167167
self.output_norm
168168
}
169169

170170
#[inline]
171-
fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Memory {
171+
fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Memory<'_> {
172172
self.output
173173
}
174174
}

models/llama/common/src/compute.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,19 @@ pub enum BlkWeight {
5454

5555
pub trait WeightLoader {
5656
type Hardware: Hardware;
57-
type Memory: Deref<Target = [ByteOf<Self::Hardware>]>;
57+
type Memory<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
58+
where
59+
Self: 's;
5860

5961
fn load_blk(
6062
&self,
6163
which: BlkWeight,
6264
iblk: usize,
6365
queue: &QueueOf<Self::Hardware>,
64-
) -> Self::Memory;
66+
) -> Self::Memory<'_>;
6567

66-
fn output_norm(&self, queue: &QueueOf<Self::Hardware>) -> Self::Memory;
67-
fn output(&self, queue: &QueueOf<Self::Hardware>) -> Self::Memory;
68+
fn output_norm(&self, queue: &QueueOf<Self::Hardware>) -> Self::Memory<'_>;
69+
fn output(&self, queue: &QueueOf<Self::Hardware>) -> Self::Memory<'_>;
6870
}
6971

7072
pub struct LlamaWorker<Ops: Operators, W> {
@@ -544,60 +546,60 @@ impl LlamaMeta {
544546

545547
impl<W: WeightLoader> WeightDecorator<W> {
546548
#[inline]
547-
pub fn attn_norm(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
549+
pub fn attn_norm(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
548550
combine(
549551
&self.attn_norm,
550552
self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue),
551553
)
552554
}
553555

554556
#[inline]
555-
pub fn attn_qkv(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
557+
pub fn attn_qkv(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
556558
combine(
557559
&self.attn_qkv,
558560
self.weights.load_blk(BlkWeight::AttnQKV, iblk, queue),
559561
)
560562
}
561563

562564
#[inline]
563-
pub fn attn_o(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
565+
pub fn attn_o(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
564566
combine(
565567
&self.attn_o,
566568
self.weights.load_blk(BlkWeight::AttnO, iblk, queue),
567569
)
568570
}
569571

570572
#[inline]
571-
pub fn ffn_norm(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
573+
pub fn ffn_norm(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
572574
combine(
573575
&self.ffn_norm,
574576
self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue),
575577
)
576578
}
577579

578580
#[inline]
579-
pub fn ffn_gate_up(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
581+
pub fn ffn_gate_up(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
580582
combine(
581583
&self.ffn_gate_up,
582584
self.weights.load_blk(BlkWeight::FfnGateUp, iblk, queue),
583585
)
584586
}
585587

586588
#[inline]
587-
pub fn ffn_down(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
589+
pub fn ffn_down(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
588590
combine(
589591
&self.ffn_down,
590592
self.weights.load_blk(BlkWeight::FfnDown, iblk, queue),
591593
)
592594
}
593595

594596
#[inline]
595-
pub fn output_norm(&self, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
597+
pub fn output_norm(&self, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
596598
combine(&self.output_norm, self.weights.output_norm(queue))
597599
}
598600

599601
#[inline]
600-
pub fn output(&self, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory> {
602+
pub fn output(&self, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
601603
combine(&self.output, self.weights.output(queue))
602604
}
603605
}

models/llama/nvidia-gpu/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Manifest for `llama-nv`, the NVIDIA-GPU backend crate of the llama model
# (workspace member at models/llama/nvidia-gpu).
[package]
name = "llama-nv"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <ydrml@hotmail.com>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# Hardware-agnostic llama implementation shared by all backends.
llama.path = "../common"
# Operator kernels from the workspace, with the CUDA backend enabled.
operators = { workspace = true, features = ["nvidia-gpu"] }

[dev-dependencies]
test-utils.workspace = true
gguf.workspace = true

models/llama/nvidia-gpu/src/lib.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use llama::{ext::Mmap, LlamaStorage, Tensor, WeightLoader};
2+
use operators::{
3+
all_reduce::NonAllReduce,
4+
cuda::{memcpy_d2h, DevByte},
5+
nvidia_gpu::Gpu,
6+
ByteOf,
7+
};
8+
use std::ops::Deref;
9+
10+
pub struct Llama {}
11+
12+
impl Llama {
13+
pub fn new(_storage: Box<[Mmap]>, _model: LlamaStorage<&'static [u8]>) -> Self {
14+
Self {}
15+
}
16+
17+
pub fn infer(&mut self, input: &[u32], cache: &mut [u8], pos: usize) -> u32 {
18+
todo!()
19+
}
20+
}
21+
22+
/// Marker type binding the generic llama worker to NVIDIA-GPU operator
/// implementations.
struct Operators;

/// Expands an operator module name to its `nvidia_gpu` implementation,
/// e.g. `op!(rms_norm)` -> `operators::rms_norm::nvidia_gpu::Operator`.
macro_rules! op {
    ($name:ident) => {
        operators::$name::nvidia_gpu::Operator
    };
}

impl llama::Operators for Operators {
    type Hardware = Gpu;
    type TopoNode = Gpu;
    type RmsNorm = op!(rms_norm);
    type MatMul = op!(mat_mul);
    type Rope = op!(rope);
    type AttnKVCached = op!(attention_kv_cached);
    type Mlp = op!(mlp);
    type Rearrange = op!(rearrange);
    // No-op all-reduce: single-device execution, nothing to reduce across.
    type AllReduce = NonAllReduce<Gpu>;

    /// Prints a tensor for debugging by first copying its device memory
    /// back into a host buffer (device pointers cannot be read directly).
    fn debug<T>(tensor: &Tensor<T>)
    where
        T: Deref<Target = [ByteOf<Self::Hardware>]>,
    {
        let tensor = tensor.as_ref().map(|mem| {
            let mut buf = vec![0u8; mem.len()];
            memcpy_d2h(&mut buf, mem);
            buf
        });
        println!("{tensor}");
    }
}
53+
54+
struct Weights {}
55+
56+
impl WeightLoader for Weights {
57+
type Hardware = Gpu;
58+
type Memory<'s> = &'s [DevByte];
59+
60+
fn load_blk(
61+
&self,
62+
which: llama::BlkWeight,
63+
iblk: usize,
64+
queue: &operators::QueueOf<Self::Hardware>,
65+
) -> Self::Memory<'_> {
66+
todo!()
67+
}
68+
69+
fn output_norm(&self, queue: &operators::QueueOf<Self::Hardware>) -> Self::Memory<'_> {
70+
todo!()
71+
}
72+
73+
fn output(&self, queue: &operators::QueueOf<Self::Hardware>) -> Self::Memory<'_> {
74+
todo!()
75+
}
76+
}

0 commit comments

Comments
 (0)