Skip to content

Commit 9785400

Browse files
committed
添加scale算子 (Add the `scale` operator)
1 parent 166d2ea commit 9785400

File tree

3 files changed: +37 −13 lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ itertools = "0.13"
4141
env_logger = "0.11"
4242
build-script-cfg = "0.0"
4343

44-
operators = { git = "https://github.com/onenewcode/operators-rs", branch = "dev", default-features = false }
44+
operators = { git = "https://github.com/onenewcode/operators-rs", rev = "f4a83f7", default-features = false }
4545

4646
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
4747
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/minicpm3/common-cpu/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ where
4141
type MatMul = op!(mat_mul);
4242
type Swiglu = op!(swiglu);
4343
type Rearrange = op!(rearrange);
44+
type Scale = op!(scale);
4445
type AttnKVCached = op!(attention_kv_cached);
4546
type AllReduce = R;
4647

models/minicpm3/common/src/compute.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use gguf::ggml_quants::digit_layout::types as ty;
33
use gguf::ggml_quants::digit_layout::DigitLayout;
44
use half::f16;
55
use itertools::{izip, Itertools};
6+
use operators::scale;
7+
use operators::scale::Scale;
68
use operators::{
79
add::{self, Add},
810
all_reduce::{self, AllReduce, ReduceOp},
@@ -30,6 +32,7 @@ pub trait Operators {
3032
type Add: Add<Self::Hardware>;
3133
type MatMul: MatMul<Self::Hardware>;
3234
type Swiglu: Swiglu<Self::Hardware>;
35+
type Scale:Scale<Self::Hardware>;
3336
type Rearrange: Rearrange<Self::Hardware>;
3437
type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
3538

@@ -82,6 +85,7 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
8285
rope: Ops::Rope,
8386
rms_norm: Ops::RmsNorm,
8487
mat_mul: Ops::MatMul,
88+
scale:Ops::Scale,
8589
swiglu: Ops::Swiglu,
8690
rearrange: Ops::Rearrange,
8791
all_reduce: Ops::AllReduce,
@@ -98,6 +102,7 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
98102
rope: Ops::Rope::new(processor),
99103
rms_norm: Ops::RmsNorm::new(processor),
100104
mat_mul: Ops::MatMul::new(processor),
105+
scale: Ops::Scale::new(processor),
101106
swiglu: Ops::Swiglu::new(processor),
102107
rearrange: Ops::Rearrange::new(processor),
103108
add: Ops::Add::new(processor),
@@ -154,18 +159,7 @@ where
154159
let scale_depth = 1.4f32;
155160
// 残差连接时权重缩放
156161
let s = scale_depth / (nblk as f32).sqrt();
157-
fn ggml_scale(embd: *mut f16, s: f16, l: usize) {
158-
if l == 0 {
159-
return;
160-
} // 如果长度为 0,则无需进行任何操作
161-
162-
unsafe {
163-
let slice = std::slice::from_raw_parts_mut(embd, l);
164-
slice.iter_mut().for_each(|x| *x *= s);
165-
}
166-
}
167-
168-
ggml_scale(x.base_mut().cast::<f16>(), f16::from_f32(scale_emb), d);
162+
169163

170164
let dnope = dk - dh;
171165
let tensor = |shape: &[usize]| Tensor::new(dt_embd, shape);
@@ -184,6 +178,10 @@ where
184178
let mut attn = attn.map(|_| buf);
185179

186180
let queue = queue_alloc.queue();
181+
// 缩放
182+
let inplace=unsafe {
183+
x.map_slice_static()};
184+
self.scale(&mut x, &inplace, scale_emb, workspace, queue_alloc)?;
187185
for iblk in 0..nblk {
188186
// norm
189187
let w = self.weights.attn_norm(iblk, queue);
@@ -619,6 +617,31 @@ where
619617
queue_alloc,
620618
)
621619
}
620+
fn scale<C, A, QA>(
621+
&self,
622+
c: &mut Tensor<C>,
623+
a: &Tensor<A>,
624+
s:f32,
625+
workspace: &mut [ByteOf<Ops::Hardware>],
626+
queue_alloc: &QA,
627+
) -> Result<(), LaunchError>
628+
where
629+
C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
630+
A: Deref<Target = [ByteOf<Ops::Hardware>]>,
631+
QA: QueueAlloc<Hardware = Ops::Hardware>,
632+
{
633+
self.scale.launch(
634+
&scale::Args {
635+
c_layout: c.layout(),
636+
c_base: c.base_mut(),
637+
a_layout: a.layout(),
638+
a_base: a.base(),
639+
s,
640+
},
641+
workspace,
642+
queue_alloc,
643+
)
644+
}
622645
fn all_reduce<X, QA>(
623646
&self,
624647
x: &mut Tensor<X>,

0 commit comments

Comments (0)