
Commit 430044c

committed
style(gpt2): tidy up the code to align with clip
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent d09838b commit 430044c

File tree

4 files changed: +64 -100 lines changed

gguf/src/lib.rs
models/gpt2/common/src/compute.rs
models/gpt2/common/src/lib.rs
models/gpt2/common/src/storage.rs


gguf/src/lib.rs

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ pub struct GGufModel<'a> {
 
 /// GGuf 张量。
 #[derive(Clone, Debug)]
-#[allow(missing_docs)]
 pub struct GGufTensor<'a> {
     pub ty: DigitLayout,
     pub shape: Box<[usize]>,

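The only change here is dropping the `#[allow(missing_docs)]` escape hatch from `GGufTensor`. Whether that changes anything depends on the crate enabling the `missing_docs` lint; below is a minimal standalone sketch of what that lint asks for on public items. It is illustrative only and not this repository's code.

// lib.rs of a hypothetical crate that opts into the lint; not this repository's code.
#![warn(missing_docs)]
//! Crate-level documentation satisfies the lint for the crate root.

/// A documented public struct.
pub struct Example {
    /// Without a field doc comment (or a local `#[allow(missing_docs)]`),
    /// this public field would trigger the lint.
    pub len: usize,
}
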
models/gpt2/common/src/compute.rs

Lines changed: 24 additions & 42 deletions
@@ -23,6 +23,7 @@ pub trait Operators {
     type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
     type AddRows: AddRows<Self::Hardware>;
     type Mlp: Gpt2Mlp<Self::Hardware>;
+
     fn debug<T>(tensor: &Tensor<T>)
     where
         T: Deref<Target = [ByteOf<Self::Hardware>]>;
@@ -66,6 +67,7 @@ pub struct Gpt2Worker<Ops: Operators, W> {
     all_reduce: Ops::AllReduce,
     add_rows: Ops::AddRows,
     mlp: Ops::Mlp,
+    pub debug: bool,
 }
 
 impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
@@ -81,6 +83,7 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
             all_reduce: Ops::AllReduce::new(node),
             add_rows: Ops::AddRows::new(processor),
             mlp: Ops::Mlp::new(processor),
+            debug: true,
         }
     }
 
@@ -136,7 +139,6 @@ where
             idx,
            idx_add,
         } = args;
-
         let Gpt2Meta {
             dt_embd,
             nblk,
@@ -145,6 +147,7 @@ where
             dh,
             ..
         } = self.meta;
+
         let workspace_size = self.workspace_size(nt, max_seq_len, max_att_len);
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let queue = queue_alloc.queue();
@@ -161,7 +164,7 @@ where
             token_embd = token_embd.merge(0..2).unwrap();
         }
         let mut x = token_embd;
-        let x1 = Tensor::new(dt_embd, x.shape());
+        let x1 = Tensor::new(x.dt(), x.shape());
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
         let qkv = Tensor::new(dt_embd, &[nt, (nh + nkvh + nkvh) * dh]);
@@ -177,10 +180,9 @@ where
             let mut qkv = qkv.clone().map(|_| buf);
             {
                 let [scale, bias] = self.weights.attn_qkv(iblk, queue);
-                let cols = bias.shape()[0];
-                let bias = bias.tile(0, &[1, cols]).broadcast(0, nt);
+                let bias = bias.broadcast(0, nt);
                 self.rearrange(&mut qkv, &bias, workspace, queue_alloc)?;
-                self.mat_mul(&mut qkv, 1., &x1, &scale, 1., workspace, queue_alloc)?;
+                self.mat_mul(&mut qkv, 1., &x1, &scale, 1., workspace, queue_alloc)?
             }
             let qkv = qkv.tile(1, &[nh + nkvh + nkvh, dh]);
             split!(qkv => q, k, v; [nh, nkvh, nkvh] @ 1);
@@ -215,14 +217,13 @@ where
                     req.pos,
                     workspace,
                     queue_alloc,
-                )?;
+                )?
                 }
             }
             {
                 let o = q.map_slice().merge(1..3).unwrap();
                 let [scale, bias] = self.weights.attn_o(iblk, queue);
-                let cols = bias.shape()[0];
-                let bias = bias.tile(0, &[1, cols]).broadcast(0, nt);
+                let bias = bias.broadcast(0, nt);
                 self.rearrange(&mut x1, &bias, workspace, queue_alloc)?;
                 self.mat_mul(&mut x1, 1., &o, &scale, 1., workspace, queue_alloc)?;
             }
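
The two attention hunks above replace `bias.tile(0, &[1, cols]).broadcast(0, nt)` with a single `bias.broadcast(0, nt)`: presumably the bias descriptor produced by the decorator now already carries a broadcastable length-1 axis (its shape comes from `mat(3 * d, 1, usage)` in lib.rs below), so the explicit `tile` is no longer needed. Broadcasting a length-1 axis is just a stride-0 view over the same buffer; the standalone sketch below illustrates the idea without using the project's `Tensor` API. The other tweak in these hunks, dropping the trailing `;` after the last `…?` call in a block, is stylistic; `?` still propagates errors the same way.

// Illustration only: reading a [1, cols] bias as if it were [nt, cols].
// The broadcast axis gets stride 0, so every "row" aliases the same memory.
fn broadcast_get(bias: &[f32], cols: usize, row: usize, col: usize) -> f32 {
    assert!(col < cols);
    let row_stride = 0; // stride of the broadcast (length-1) dimension
    let col_stride = 1;
    bias[row * row_stride + col * col_stride]
}

fn main() {
    let bias = [0.1f32, 0.2, 0.3, 0.4]; // a 4-element bias, viewed as [1, 4]
    let nt = 3; // broadcast to [nt, 4]: three logical rows, one physical buffer
    for row in 0..nt {
        // Every row of the broadcast view sees the same values.
        assert_eq!(broadcast_get(&bias, 4, row, 2), 0.3);
    }
    println!("all {nt} rows alias the same {}-element bias", bias.len());
}
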
@@ -506,50 +507,40 @@ where
 }
 
 struct WeightDecorator<W> {
-    attn_norm_w: Tensor<usize>,
-    attn_norm_b: Tensor<usize>,
+    pos_embd: Tensor<usize>,
+    output_weight: Tensor<usize>,
+    norm: Tensor<usize>,
+
     attn_qkv_w: Tensor<usize>,
     attn_qkv_b: Tensor<usize>,
     attn_o_w: Tensor<usize>,
     attn_o_b: Tensor<usize>,
 
-    ffn_norm_w: Tensor<usize>,
-    ffn_norm_b: Tensor<usize>,
     ffn_up_w: Tensor<usize>,
     ffn_up_b: Tensor<usize>,
     ffn_down_w: Tensor<usize>,
     ffn_down_b: Tensor<usize>,
 
-    output_norm_w: Tensor<usize>,
-    output_norm_b: Tensor<usize>,
-    output_weight: Tensor<usize>,
-    pos_embd: Tensor<usize>,
-
     weights: W,
 }
 
 impl Gpt2Meta {
     fn decorator<W>(&self, weights: W) -> WeightDecorator<W> {
         use crate::TensorUsage::Computation;
         WeightDecorator {
-            attn_norm_w: self.attn_norm_w(),
-            attn_norm_b: self.attn_norm_b(),
+            pos_embd: self.pos_embd(),
+            output_weight: self.output_weight(),
+            norm: self.norm(),
+
             attn_qkv_w: self.attn_qkv_w(Computation),
-            attn_qkv_b: self.attn_qkv_b(),
+            attn_qkv_b: self.attn_qkv_b(Computation),
             attn_o_w: self.attn_o_w(Computation),
-            attn_o_b: self.attn_o_b(),
+            attn_o_b: self.attn_o_b(Computation),
 
-            ffn_norm_w: self.ffn_norm_w(),
-            ffn_norm_b: self.ffn_norm_b(),
             ffn_up_w: self.ffn_up_w(Computation),
-            ffn_up_b: self.ffn_up_b(),
+            ffn_up_b: self.ffn_up_b(Computation),
             ffn_down_w: self.ffn_down_w(Computation),
-            ffn_down_b: self.ffn_down_b(),
-
-            output_norm_w: self.output_norm_w(),
-            output_norm_b: self.output_norm_b(),
-            output_weight: self.output_weight(),
-            pos_embd: self.pos_embd(),
+            ffn_down_b: self.ffn_down_b(Computation),
 
             weights,
         }
@@ -563,10 +554,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
         queue: &QueueOf<W::Hardware>,
     ) -> [Tensor<W::Memory<'_>>; 2] {
         let [w, b] = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue);
-        [
-            self.attn_norm_w.clone().map(|_| w),
-            self.attn_norm_b.clone().map(|_| b),
-        ]
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
     }
 
     pub fn attn_qkv(
@@ -595,10 +583,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
        queue: &QueueOf<W::Hardware>,
     ) -> [Tensor<W::Memory<'_>>; 2] {
         let [w, b] = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue);
-        [
-            self.ffn_norm_w.clone().map(|_| w),
-            self.ffn_norm_b.clone().map(|_| b),
-        ]
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
     }
 
     pub fn ffn_up(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
@@ -623,10 +608,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
 
     pub fn output_norm(&self, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
         let [w, b] = self.weights.output_norm(queue);
-        [
-            self.output_norm_w.clone().map(|_| w),
-            self.output_norm_b.clone().map(|_| b),
-        ]
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
     }
 
     pub fn output_weight(&self, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
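
All three LayerNorm loaders (`attn_norm`, `ffn_norm`, `output_norm`) now share the single `norm` shape descriptor and simply re-bind it to whichever weight or bias buffer the loader returns, which is what collapses the six `*_norm_w` / `*_norm_b` fields into one. Below is a minimal standalone sketch of that clone-and-rebind pattern, using a hypothetical `Desc` type in place of the project's `Tensor<usize>`.

/// Hypothetical stand-in for a shape descriptor such as `Tensor<usize>`:
/// it knows the layout and can be re-bound to concrete storage later.
#[derive(Clone)]
struct Desc {
    shape: Vec<usize>,
}

/// A descriptor bound to a borrowed buffer, like `Tensor<W::Memory<'_>>`.
struct Bound<'a> {
    shape: Vec<usize>,
    data: &'a [f32],
}

impl Desc {
    /// Re-bind this descriptor to a concrete buffer, like `Tensor::map(|_| w)`.
    fn map<'a>(&self, data: &'a [f32]) -> Bound<'a> {
        Bound { shape: self.shape.clone(), data }
    }
}

fn main() {
    let d = 768;
    let norm = Desc { shape: vec![d] }; // one descriptor for every LayerNorm site
    let (w, b) = (vec![1.0f32; d], vec![0.0f32; d]); // as if returned by a weight loader
    // Same pattern as `[self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]`.
    let pair = [norm.clone().map(&w), norm.clone().map(&b)];
    assert_eq!(pair[0].shape, pair[1].shape);
    assert_eq!(pair[0].data.len(), d);
    assert_eq!(pair[1].data.len(), d);
}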

models/gpt2/common/src/lib.rs

Lines changed: 28 additions & 45 deletions
@@ -84,71 +84,54 @@ impl Gpt2Meta {
     pub fn position_embd(&self) -> Tensor<usize> {
         self.embd(self.nctx)
     }
-    // ln1_weight
-    pub fn attn_norm_w(&self) -> Tensor<usize> {
-        self.norm()
-    }
-    // ln1_bias
-    pub fn attn_norm_b(&self) -> Tensor<usize> {
-        self.norm()
-    }
-    // attn_qkvw
+
     pub fn attn_qkv_w(&self, usage: TensorUsage) -> Tensor<usize> {
-        self.mat(3 * self.d, self.d, usage)
+        let &Self { d, .. } = self;
+        self.mat(3 * d, d, usage)
     }
-    // attn_qkvb
-    pub fn attn_qkv_b(&self) -> Tensor<usize> {
-        Tensor::new(self.dt_embd, &[3 * self.d])
+
+    pub fn attn_qkv_b(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(3 * d, 1, usage)
     }
-    // attn_projw
+
     pub fn attn_o_w(&self, usage: TensorUsage) -> Tensor<usize> {
-        self.mat(self.d, self.d, usage)
+        let &Self { d, .. } = self;
+        self.mat(d, d, usage)
     }
-    // attn_projb
-    pub fn attn_o_b(&self) -> Tensor<usize> {
-        Tensor::new(self.dt_embd, &[self.d])
-    }
-    // ln2_weight
-    pub fn ffn_norm_w(&self) -> Tensor<usize> {
-        self.norm()
-    }
-    // ln2_bias
-    pub fn ffn_norm_b(&self) -> Tensor<usize> {
-        self.norm()
+
+    pub fn attn_o_b(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(d, 1, usage)
     }
-    // fcw
+
     pub fn ffn_up_w(&self, usage: TensorUsage) -> Tensor<usize> {
-        self.mat(4 * self.d, self.d, usage)
+        let &Self { d, di, .. } = self;
+        self.mat(di, d, usage)
     }
-    // fcb
-    pub fn ffn_up_b(&self) -> Tensor<usize> {
-        Tensor::new(self.dt_embd, &[4 * self.d])
+
+    pub fn ffn_up_b(&self, _usage: TensorUsage) -> Tensor<usize> {
+        Tensor::new(self.dt_embd, &[self.di])
     }
-    // fcprojw
+
     pub fn ffn_down_w(&self, usage: TensorUsage) -> Tensor<usize> {
-        self.mat(self.d, 4 * self.d, usage)
+        let &Self { d, di, .. } = self;
+        self.mat(d, di, usage)
     }
-    // fcprojb
-    pub fn ffn_down_b(&self) -> Tensor<usize> {
+
+    pub fn ffn_down_b(&self, _usage: TensorUsage) -> Tensor<usize> {
         Tensor::new(self.dt_embd, &[self.d])
     }
-    // lnfw
-    pub fn output_norm_w(&self) -> Tensor<usize> {
-        self.norm()
-    }
-    // lnfb
-    pub fn output_norm_b(&self) -> Tensor<usize> {
-        self.norm()
-    }
-    // output.weight
+
     pub fn output_weight(&self) -> Tensor<usize> {
         Tensor::new(self.dt_embd, &[self.nvoc, self.d])
     }
 
-    fn norm(&self) -> Tensor<usize> {
+    pub fn norm(&self) -> Tensor<usize> {
        let &Self { dt_norm, d, .. } = self;
        Tensor::new(dt_norm, &[d])
     }
+
     pub fn pos_embd(&self) -> Tensor<usize> {
         let &Self { nvoc, d, .. } = self;
         Tensor::new(self.dt_embd, &[nvoc, d])
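
The Gpt2Meta accessors now derive every per-block shape from `d` and `di` instead of hard-coding the GPT-2 ratio `di = 4 * d`, and the matrix biases go through the same `mat(..)` helper as the weights (with a second dimension of 1) rather than through raw `Tensor::new`. The worked example below shows the resulting shapes, assuming GPT-2 small sizes (`d = 768`, `di = 3072`) and that `mat(r, c, _)` describes an r-by-c matrix; that last point is an assumption, since the real helper may also encode dtype and transposition for the requested `TensorUsage`.

// Worked example of the shapes Gpt2Meta now derives (assumed GPT-2 small sizes).
fn main() {
    let (d, di) = (768usize, 3072usize); // di happens to equal 4 * d for GPT-2
    let attn_qkv_w = [3 * d, d]; // [2304, 768]
    let attn_qkv_b = [3 * d, 1]; // bias as a one-column matrix, broadcast per token at the use site
    let attn_o_w = [d, d];       // [768, 768]
    let ffn_up_w = [di, d];      // [3072, 768], previously written as [4 * d, d]
    let ffn_down_w = [d, di];    // [768, 3072]
    println!("{attn_qkv_w:?} {attn_qkv_b:?} {attn_o_w:?} {ffn_up_w:?} {ffn_down_w:?}");
}

Note that `ffn_up_b` and `ffn_down_b` keep plain 1-D shapes (`[di]` and `[d]`) built with `Tensor::new`; their new `_usage` parameter is unused and appears to exist only to keep the call sites uniform with the other accessors.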

models/gpt2/common/src/storage.rs

Lines changed: 12 additions & 12 deletions
@@ -60,19 +60,19 @@ impl<'a> Storage<&'a [u8]> {
         #[rustfmt::skip]
         let blocks = (0..meta.nblk)
             .map(|i| BlkStorage {
-                attn_qkv_b:  gguf.tensors[&*format!("blk.{i}.attn_qkv.bias"     )].data,
-                attn_qkv_w:  gguf.tensors[&*format!("blk.{i}.attn_qkv.weight"   )].data,
-                attn_o_b:    gguf.tensors[&*format!("blk.{i}.attn_output.bias"  )].data,
-                attn_o_w:    gguf.tensors[&*format!("blk.{i}.attn_output.weight")].data,
-                attn_norm_b: gguf.tensors[&*format!("blk.{i}.attn_norm.bias"    )].data,
-                attn_norm_w: gguf.tensors[&*format!("blk.{i}.attn_norm.weight"  )].data,
+                attn_norm_w: gguf.tensors[&*format!("blk.{i}.attn_norm.weight"  )].data,
+                attn_norm_b: gguf.tensors[&*format!("blk.{i}.attn_norm.bias"    )].data,
+                attn_qkv_w:  gguf.tensors[&*format!("blk.{i}.attn_qkv.weight"   )].data,
+                attn_qkv_b:  gguf.tensors[&*format!("blk.{i}.attn_qkv.bias"     )].data,
+                attn_o_w:    gguf.tensors[&*format!("blk.{i}.attn_output.weight")].data,
+                attn_o_b:    gguf.tensors[&*format!("blk.{i}.attn_output.bias"  )].data,
 
-                ffn_up_b:    gguf.tensors[&*format!("blk.{i}.ffn_up.bias"       )].data,
-                ffn_up_w:    gguf.tensors[&*format!("blk.{i}.ffn_up.weight"     )].data,
-                ffn_down_b:  gguf.tensors[&*format!("blk.{i}.ffn_down.bias"     )].data,
-                ffn_down_w:  gguf.tensors[&*format!("blk.{i}.ffn_down.weight"   )].data,
-                ffn_norm_b:  gguf.tensors[&*format!("blk.{i}.ffn_norm.bias"     )].data,
-                ffn_norm_w:  gguf.tensors[&*format!("blk.{i}.ffn_norm.weight"   )].data,
+                ffn_norm_w:  gguf.tensors[&*format!("blk.{i}.ffn_norm.weight"   )].data,
+                ffn_norm_b:  gguf.tensors[&*format!("blk.{i}.ffn_norm.bias"     )].data,
+                ffn_up_w:    gguf.tensors[&*format!("blk.{i}.ffn_up.weight"     )].data,
+                ffn_up_b:    gguf.tensors[&*format!("blk.{i}.ffn_up.bias"       )].data,
+                ffn_down_w:  gguf.tensors[&*format!("blk.{i}.ffn_down.weight"   )].data,
+                ffn_down_b:  gguf.tensors[&*format!("blk.{i}.ffn_down.bias"     )].data,
             })
             .collect();
 
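This hunk only reorders the `BlkStorage` initializers into the norm, qkv, output and norm, up, down order used elsewhere in the commit; the GGUF tensor names are unchanged. Field order in a Rust struct literal has no effect on the resulting value, as the small sketch below illustrates (a hypothetical two-field struct, not the real `BlkStorage`).

// Illustration only: both initializations build the same value.
#[derive(PartialEq, Debug)]
struct Blk<'a> {
    attn_norm_w: &'a [u8],
    attn_qkv_w: &'a [u8],
}

fn main() {
    let (norm, qkv) = ([0u8; 2], [1u8; 2]);
    // Same value whether the norm or the qkv field is written first.
    let a = Blk { attn_norm_w: &norm[..], attn_qkv_w: &qkv[..] };
    let b = Blk { attn_qkv_w: &qkv[..], attn_norm_w: &norm[..] };
    assert_eq!(a, b);
}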