Commit 2b8c97a (1 parent: 050c008)

todo: add all reduce

Signed-off-by: YdrMaster <ydrml@hotmail.com>
3 files changed: +43 -13 lines


Cargo.toml (1 addition, 1 deletion)

@@ -16,4 +16,4 @@ test-utils.path = "test-utils"
 
 ggus = { git = "https://github.com/YdrMaster/gguf", rev = "e64d758" }
 ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "5c6b969" }
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "b956b29", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "15cebdd", default-features = false }

models/llama/common-cpu/src/lib.rs (4 additions, 2 deletions)

@@ -4,7 +4,7 @@ use llama::{
     RandomSample, Tensor, WeightLoader,
 };
 use operators::{
-    common_cpu::{Cpu, ThisThread},
+    common_cpu::{Cpu, ThisThread, Threads},
     random_sample::{common_cpu::Operator as CpuOp, KVPair, SampleArgs},
     ByteOf, QueueOf,
 };
@@ -31,7 +31,7 @@ impl Llama {
             _storage,
             token_embed,
             single: LlamaWorker::new(
-                &Cpu,
+                todo!(),
                 meta,
                 Weights {
                     blks: blocks,
@@ -116,12 +116,14 @@ macro_rules! op!
 
 impl llama::Operators for Operators {
     type Hardware = Cpu;
+    type TopoNode = Threads;
     type RmsNorm = op!(rms_norm);
     type MatMul = op!(mat_mul);
     type Rope = op!(rope);
     type AttnKVCached = op!(attention_kv_cached);
     type Mlp = op!(mlp);
     type Rearrange = op!(rearrange);
+    type AllReduce = op!(all_reduce);
 
     fn debug<T>(tensor: &Tensor<T>)
     where

models/llama/common/src/compute.rs (38 additions, 10 deletions)

@@ -2,25 +2,28 @@
 use gguf::ggml_quants::digit_layout::{types as ty, DigitLayout};
 use itertools::izip;
 use operators::{
+    all_reduce::{AllReduce, ReduceOp},
     attention_kv_cached::AttnKVCached,
     mat_mul::MatMul,
     mlp::Mlp,
     rearrange::Rearrange,
     rms_norm::RmsNorm,
     rope::{Rope, Seq},
-    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, Workspace,
+    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
 use tensor::{dt_size, split, Tensor};
 
 pub trait Operators {
     type Hardware: Hardware;
+    type TopoNode: TopoNode<Self::Hardware>;
     type RmsNorm: RmsNorm<Self::Hardware>;
     type MatMul: MatMul<Self::Hardware>;
     type Rope: Rope<Self::Hardware>;
     type AttnKVCached: AttnKVCached<Self::Hardware>;
     type Mlp: Mlp<Self::Hardware>;
     type Rearrange: Rearrange<Self::Hardware>;
+    type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
 
     fn debug<T>(tensor: &Tensor<T>)
     where
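The new TopoNode associated type threads a communication topology through the Operators trait: per-device operators are still built from a processor handle (now obtained via node.processor()), while the AllReduce operator is built from the node itself. As a conceptual sketch only, inferred from how this diff uses the type (the real trait is defined in operators-rs and may differ; Hardware, TopoNode, and SingleNode below are illustrative stand-ins):

// Conceptual sketch, not the operators-rs definition.
pub trait Hardware {}

pub trait TopoNode<H: Hardware> {
    // The worker calls `node.processor()` to construct per-device operators
    // (RmsNorm, MatMul, ...), so a node must expose its hardware handle.
    fn processor(&self) -> &H;
}

// A trivial single-device node that just wraps the hardware handle;
// an all-reduce over such a node can be a no-op.
pub struct SingleNode<H: Hardware>(pub H);

impl<H: Hardware> TopoNode<H> for SingleNode<H> {
    fn processor(&self) -> &H {
        &self.0
    }
}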
@@ -73,19 +76,21 @@ pub struct LlamaWorker<Ops: Operators, W> {
     attn_kv_cached: Ops::AttnKVCached,
     mlp: Ops::Mlp,
     rearrange: Ops::Rearrange,
+    all_reduce: Ops::AllReduce,
 }
 
 impl<Ops: Operators, W> LlamaWorker<Ops, W> {
-    pub fn new(processor: &Ops::Hardware, meta: LlamaMeta, weights: W) -> Self {
+    pub fn new(node: &Ops::TopoNode, meta: LlamaMeta, weights: W) -> Self {
         Self {
             weights: meta.decorator(weights),
             meta,
-            rms_norm: Ops::RmsNorm::new(processor),
-            mat_mul: Ops::MatMul::new(processor),
-            rope: Ops::Rope::new(processor),
-            attn_kv_cached: Ops::AttnKVCached::new(processor),
-            mlp: Ops::Mlp::new(processor),
-            rearrange: Ops::Rearrange::new(processor),
+            rms_norm: Ops::RmsNorm::new(node.processor()),
+            mat_mul: Ops::MatMul::new(node.processor()),
+            rope: Ops::Rope::new(node.processor()),
+            attn_kv_cached: Ops::AttnKVCached::new(node.processor()),
+            mlp: Ops::Mlp::new(node.processor()),
+            rearrange: Ops::Rearrange::new(node.processor()),
+            all_reduce: Ops::AllReduce::new(node),
         }
     }
 
@@ -240,7 +245,7 @@ where
             self.mat_mul(&mut x, 1., &x1, &w, 1., workspace, queue_alloc)?;
 
             if distribute > 1 {
-                todo!("all reduce")
+                self.all_reduce(&mut x, workspace, queue_alloc)?;
             }
 
             let w = self.weights.ffn_norm(iblk, queue);
@@ -250,7 +255,7 @@ where
             self.mlp(&mut x, &x1, iblk, mlp_alpha, true, workspace, queue_alloc)?;
 
             if distribute > 1 {
-                todo!("all reduce")
+                self.all_reduce(&mut x, workspace, queue_alloc)?;
             }
         }
 
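These two call sites replace the todo!("all reduce") placeholders and match the standard tensor-parallel pattern: when the weights are split across `distribute` workers, each worker produces only a partial result of the attention output projection and of the MLP, and a sum all-reduce merges the partials before the next normalization. A single-process illustration of that semantics (all_reduce_sum is a hypothetical helper, not the operators-rs API):

// Illustration only: element-wise sum across the local buffers of all workers,
// then broadcast the sum back so every worker ends up with the full result.
fn all_reduce_sum(workers: &mut [Vec<f32>]) {
    let len = workers[0].len();
    let mut sum = vec![0.0f32; len];
    for w in workers.iter() {
        for (acc, v) in sum.iter_mut().zip(w.iter()) {
            *acc += *v;
        }
    }
    for w in workers.iter_mut() {
        w.copy_from_slice(&sum);
    }
}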

@@ -483,6 +488,29 @@ where
             queue_alloc,
         )
     }
+
+    fn all_reduce<X, QA>(
+        &self,
+        x: &mut Tensor<X>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        X: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.all_reduce.launch(
+            &operators::all_reduce::Args {
+                dst_layout: x.layout(),
+                dst_base: x.base_mut(),
+                src_layout: x.layout(),
+                src_base: x.base(),
+                op: ReduceOp::Sum,
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
 }
 
 struct WeightDecorator<W> {