Commit ab671b9

style: code cleanup
Signed-off-by: YdrMaster <[email protected]>
1 parent 72fc503 · commit ab671b9

3 files changed, +21 -33 lines changed
models/gpt2/common/src/compute.rs

Lines changed: 12 additions & 24 deletions
@@ -96,23 +96,6 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
     pub const fn meta(&self) -> &Gpt2Meta {
         &self.meta
     }
-
-    pub fn workspace_size(&self, nt: usize, max_seq_len: usize, max_att_len: usize) -> usize {
-        let Gpt2Meta {
-            nh, nkvh, dh, di, ..
-        } = self.meta;
-
-        let embd = self.meta.embd(nt);
-        let dt = embd.dt();
-        let embd = embd.take();
-
-        let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take();
-        let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take();
-        let att = Tensor::new(dt, &[nh, max_seq_len, max_att_len]).take();
-
-        let up = Tensor::new(dt, &[nt, di]).take();
-        embd + (qkv + q + att).max(up)
-    }
 }
 
 impl<Ops, W> Gpt2Worker<Ops, W>
@@ -153,13 +136,18 @@ where
             self.add_rows(&mut embd, &pos_embd, &idx, workspace, queue_alloc)?
         }
 
-        let nt = embd.shape()[0];
         let mut x = embd;
-        let x1 = Tensor::new(x.dt(), x.shape());
-        let qkv = Tensor::new(x.dt(), &[nt, (nh + nkvh + nkvh) * dh]);
-        let up = Tensor::new(x.dt(), &[nt, di]);
+        let nt = x.shape()[0];
+
+        let tensor = |shape: &[usize]| Tensor::new(x.dt(), shape);
+        let x1 = tensor(x.shape());
+        let qkv = tensor(&[nt, (nh + nkvh + nkvh) * dh]);
+        let q = tensor(&[max_seq_len, nh, dh]).take();
+        let att = tensor(&[nh, max_seq_len, max_att_len]).take();
+        let up = tensor(&[nt, di]);
+
+        let workspace_size = *x1.get() + (*qkv.get() + q + att).max(*up.get());
 
-        let workspace_size = self.workspace_size(nt, max_seq_len, max_att_len);
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
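With the standalone `workspace_size` method removed, the GPT-2 forward pass now sizes its scratch buffer inline: the attention-path buffers (`qkv`, `q`, `att`) and the MLP `up` buffer are never live at the same time, so only the larger group counts on top of the always-live `x1`. A minimal sketch of that arithmetic, using plain byte counts with made-up sizes instead of the repo's `Tensor` type:

    // Sketch of the peak-workspace arithmetic; all sizes are hypothetical bytes.
    fn workspace_bytes(x1: usize, qkv: usize, q: usize, att: usize, up: usize) -> usize {
        // x1 is live for the whole layer; the attention buffers and the MLP
        // `up` buffer reuse the remaining region, so take the larger of the two.
        x1 + (qkv + q + att).max(up)
    }

    fn main() {
        let size = workspace_bytes(4096, 6144, 2048, 1024, 16384);
        assert_eq!(size, 4096 + 16384); // the MLP path dominates in this example
        println!("workspace: {size} bytes");
    }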
@@ -253,9 +241,9 @@ where
                 if src != dst {
                     let src = unsafe { x.map_slice_static() }.index(0, src);
                     let mut dst = x.map_slice_mut().index(0, dst);
-                    self.rearrange(&mut dst, &src, workspace, queue_alloc)?;
+                    self.rearrange(&mut dst, &src, workspace, queue_alloc)?
                 }
-                dst += 1;
+                dst += 1
             }
         }
         assert_eq!(dst, logits.shape()[0]);
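The remaining changes in this file only drop trailing semicolons where the statement is already the last expression of a `()`-typed block: `?` applied to a `Result<(), _>` yields `()`, and `dst += 1` is itself a `()` expression, so both can stand as block tails. A minimal sketch of the same style, with a stand-in `step` function that is not part of the repo:

    fn step(ok: bool) -> Result<(), String> {
        if ok { Ok(()) } else { Err("failed".into()) }
    }

    fn run(n: usize) -> Result<(), String> {
        let mut dst = 0;
        for i in 0..n {
            if i % 2 == 0 {
                // `?` on a Result<(), _> leaves `()`, so no trailing `;` is
                // needed when the call closes the block.
                step(true)?
            }
            dst += 1
        }
        assert_eq!(dst, n);
        Ok(())
    }

    fn main() {
        run(4).unwrap()
    }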

models/gpt2/common/src/lib.rs

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,6 @@ use common::Distribution;
 use gguf::ggml_quants::digit_layout::DigitLayout;
 
 pub use args::{Args as GPT2Args, Request as GPT2Request};
-pub use common::Contiguous;
 pub use compute::{BlkWeight, Gpt2Worker, Operators, WeightLoader};
 pub use storage::{BlkStorage as GPT2BlkStorage, Storage as GPT2Storage};
 pub use tensor::{RandomSample, Tensor};
@@ -16,6 +15,7 @@ pub mod ext {
         ggml_quants,
     };
 }
+
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub enum GPT2BlkWeight {
     AttnQkvB,
@@ -31,6 +31,7 @@ pub enum GPT2BlkWeight {
     FfnNormB,
     FfnNormW,
 }
+
 #[derive(Clone, Debug)]
 pub struct Gpt2Meta {
     pub dt_embd: DigitLayout,
@@ -72,6 +73,7 @@ impl Gpt2Meta {
             ..self.clone()
         }
     }
+
     pub fn blk(&self) -> GPT2BlkStorage<usize> {
         use TensorUsage::Storage as TensorMem;
         GPT2BlkStorage {

models/llama/common/src/compute.rs

Lines changed: 6 additions & 8 deletions
@@ -172,6 +172,7 @@ where
         let att = tensor(&[nh, max_seq_len, max_att_len]).take();
         let gate_up = tensor(&[if self.meta.is_moe() { 1 } else { nt }, di * 2]);
         let routes = tensor(&[nt, nexp]);
+        let mut routes_host = routes.clone().map(Blob::new).take();
 
         let workspace_size = *x1.get()
             + (*qkv.get() + q + att)
@@ -204,8 +205,7 @@ where
         for iblk in 0..nblk {
             {
                 let w = self.weights.attn_norm(iblk, queue);
-                self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
-                drop(w);
+                self.rms_norm(&mut x1, &x, w, workspace, queue_alloc)?;
 
                 let (buf, workspace) = workspace.split_at_mut(*qkv.get());
                 let mut qkv = qkv.clone().map(|_| buf);
@@ -272,8 +272,7 @@ where
             self.all_reduce(&mut x, workspace, queue_alloc)?;
 
             let w = self.weights.ffn_norm(iblk, queue);
-            self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
-            drop(w);
+            self.rms_norm(&mut x1, &x, w, workspace, queue_alloc)?;
 
             if !self.meta.is_moe() {
                 let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
@@ -291,7 +290,6 @@ where
                 let residual = if self.id == 0 { 1. } else { 0. };
                 self.mat_mul(&mut x, residual, &gate, &w, 1., workspace, queue_alloc)?
             } else {
-                let mut routes_host = routes.clone().map(Blob::new).take();
                 // gate_inp
                 {
                     let (buf, workspace) = workspace.split_at_mut(*routes.get());
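Together with the `+175` line in the first hunk of this file, this hunk hoists the `routes_host` blob out of the per-block loop: it is now allocated once before iterating over the layers and reused for every MoE block instead of being rebuilt on each iteration. A minimal sketch of the same hoisting pattern, using a plain `Vec` as a stand-in for the repo's `Blob` and made-up sizes:

    // Hypothetical sizes: nblk layers, nt tokens, nexp experts.
    fn route_all(nblk: usize, nt: usize, nexp: usize) {
        // Allocate the host-side routing buffer once, outside the layer loop.
        let mut routes_host = vec![0f32; nt * nexp];
        for iblk in 0..nblk {
            // Stand-in for copying this layer's router logits back to the host.
            for v in routes_host.iter_mut() {
                *v = iblk as f32;
            }
            // ... select experts per token from routes_host ...
        }
    }

    fn main() {
        route_all(4, 2, 8)
    }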
@@ -332,7 +330,7 @@ where
                     }
                 }
             }
-            self.all_reduce(&mut x, workspace, queue_alloc)?;
+            self.all_reduce(&mut x, workspace, queue_alloc)?
         }
         if logits.shape()[0] == 0 {
             return Ok(());
@@ -359,7 +357,7 @@ where
         {
             let inplace = unsafe { x.map_slice_static() };
             let w = self.weights.output_norm(queue);
-            self.rms_norm(&mut x, &inplace, &w, workspace, queue_alloc)?
+            self.rms_norm(&mut x, &inplace, w, workspace, queue_alloc)?
         }
         let w = self.weights.output(queue);
         self.mat_mul(&mut logits, 0., &x, &w, 1., workspace, queue_alloc)
@@ -397,7 +395,7 @@ where
         &self,
         y: &mut Tensor<Y>,
         x: &Tensor<X>,
-        w: &Tensor<W_>,
+        w: Tensor<W_>,
         workspace: &mut [ByteOf<Ops::Hardware>],
         queue_alloc: &QA,
     ) -> Result<(), LaunchError>
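This signature change is what makes the earlier `drop(w)` removals possible: `rms_norm` now consumes the weight tensor, so the handle is dropped inside the call rather than lingering at the call site. A minimal sketch of the ownership difference, with stand-in types rather than the repo's `Tensor` and worker:

    struct Weight(Vec<f32>);

    // Old shape: the caller keeps ownership and must drop `w` explicitly if it
    // should not outlive the call. (The body is a placeholder, not a real RMS norm.)
    fn rms_norm_by_ref(x: &mut [f32], w: &Weight) {
        for (xi, wi) in x.iter_mut().zip(&w.0) {
            *xi *= *wi;
        }
    }

    // New shape: the callee consumes `w`, so it is dropped when the call returns.
    fn rms_norm_by_value(x: &mut [f32], w: Weight) {
        for (xi, wi) in x.iter_mut().zip(&w.0) {
            *xi *= *wi;
        }
    }

    fn main() {
        let mut x = vec![1.0, 2.0];

        let w = Weight(vec![0.5, 0.5]);
        rms_norm_by_ref(&mut x, &w);
        drop(w); // explicit release, as the old call sites did

        let w = Weight(vec![0.5, 0.5]);
        rms_norm_by_value(&mut x, w); // no drop needed at the call site
    }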
