Skip to content

Commit 086307c

Browse files
authored
Merge pull request #13 from InfiniTensor/dev-operators
move random search to operators-rs
2 parents 7847e93 + 0b19878 commit 086307c

File tree

37 files changed

+643
-772
lines changed

37 files changed

+643
-772
lines changed

.github/workflows/build.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ jobs:
2626
- name: Checkout code
2727
uses: actions/checkout@v4
2828

29+
- name: cuda-toolkit
30+
uses: Jimver/cuda-toolkit@v0.2.16
31+
with:
32+
method: 'network'
33+
2934
- name: Check format
3035
run: cargo fmt --check
3136

Cargo.lock

Lines changed: 32 additions & 40 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ members = [
44
"tensor",
55
"tokenizer",
66
"causal-lm",
7-
"sample",
87
"service",
98
"web-api",
109
"xtask",
@@ -35,7 +34,6 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
3534
digit-layout = "0.0"
3635
build-script-cfg = "0.0"
3736

38-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "04e71d5", default-features = false }
39-
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "343b0e0" }
40-
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "343b0e0" }
37+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "5a88159", default-features = false }
38+
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "fb088b6" }
4139
search-neuware-tools = "0.0"

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
![GitHub contributors](https://img.shields.io/github/contributors/InfiniTensor/transformer-rs)
88
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/InfiniTensor/transformer-rs)
99

10+
[**使用指南**](/docs/user-guide/doc.md)
11+
1012
[YdrMaster/llama2.rs](https://github.com/YdrMaster/llama2.rs) 发展来的手写 transformer 模型项目。
1113

1214
## 使用

causal-lm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
99
[dependencies]
1010
common = { path = "../common" }
1111
tensor = { path = "../tensor" }
12-
sample = { path = "../sample" }
1312
digit-layout.workspace = true
13+
operators.workspace = true

causal-lm/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use std::path::Path;
1010
use tensor::{udim, Tensor};
1111

1212
pub use decoding::DecodingMeta;
13+
pub use operators::random_sample::SampleArgs;
1314
pub use query_context::QueryContext;
14-
pub use sample::SampleArgs;
1515

1616
/// 从文件系统加载的模型。
1717
pub trait Model: Sized {
@@ -142,7 +142,7 @@ where
142142

143143
let args = [SampleMeta {
144144
num_decode: 1,
145-
args: SampleArgs::default(),
145+
args: SampleArgs::ARG_MAX,
146146
}];
147147
let tokens = CausalLM::sample(&model, args, logits);
148148

devices/common-cpu/src/lib.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@ macro_rules! slice {
77

88
mod gather;
99

10-
use common::utok;
10+
use common::{f16, utok};
1111
use common_devices::{Operators, SliceOn};
12+
use digit_layout::types::F16;
1213
use operators::{
13-
fuesd_softmax::common_cpu as softmax, mat_mul::common_cpu as mat_mul,
14-
reform::common_cpu as reform, rms_norm::common_cpu as rms_norm, rope::common_cpu as rope,
15-
swiglu::common_cpu as swiglu, Operator, QueueOf,
14+
fuesd_softmax::common_cpu as softmax,
15+
mat_mul::common_cpu as mat_mul,
16+
random_sample::{common_cpu as random_sample, Args, KVPair, SampleArgs},
17+
reform::common_cpu as reform,
18+
rms_norm::common_cpu as rms_norm,
19+
rope::common_cpu as rope,
20+
swiglu::common_cpu as swiglu,
21+
Operator, QueueOf,
1622
};
1723
use std::ops::{Deref, DerefMut};
1824
use tensor::Tensor;
@@ -29,6 +35,23 @@ pub struct CpuKernels {
2935
rope: rope::Operator,
3036
softmax: softmax::Operator,
3137
swiglu: swiglu::Operator,
38+
sample: random_sample::Operator,
39+
}
40+
41+
impl CpuKernels {
42+
pub fn sample(&self, temperature: f32, top_p: f32, top_k: usize, logits: &[f16]) -> utok {
43+
let mut kv_pair = KVPair::new(0, f16::ZERO);
44+
let mut args = Args::<Cpu>::new(F16, logits.len());
45+
args.kv_pair_base = &mut kv_pair as *mut _ as _;
46+
args.data_base = logits.as_ptr() as _;
47+
args.detail = SampleArgs {
48+
temperature,
49+
top_p,
50+
top_k,
51+
};
52+
self.sample.launch(&args, &ThisThread).unwrap();
53+
kv_pair.idx() as _
54+
}
3255
}
3356

3457
impl Default for CpuKernels {
@@ -40,6 +63,7 @@ impl Default for CpuKernels {
4063
rope: rope::Operator::new(&Cpu),
4164
softmax: softmax::Operator::new(&Cpu),
4265
swiglu: swiglu::Operator::new(&Cpu),
66+
sample: random_sample::Operator::new(&Cpu),
4367
}
4468
}
4569
}

devices/nvidia-gpu/Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
1010
common = { path = "../../common" }
1111
common-devices = { path = "../common" }
1212
tensor = { path = "../../tensor" }
13-
sample = { path = "../../sample" }
1413
rand = "0.8"
1514
operators = { workspace = true, features = ["nvidia-gpu"] }
1615
digit-layout.workspace = true
1716

1817
[build-dependencies]
1918
build-script-cfg.workspace = true
2019
search-cuda-tools.workspace = true
21-
cc = "1.0"

devices/nvidia-gpu/build.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
fn main() {
1+
fn main() {
22
use build_script_cfg::Cfg;
3-
use search_cuda_tools::find_cuda_root;
3+
use search_cuda_tools::{find_cuda_root, find_nccl_root};
44

55
let cuda = Cfg::new("detected_cuda");
6+
let nccl = Cfg::new("detected_nccl");
67
if find_cuda_root().is_some() {
78
cuda.define();
8-
println!("cargo:rerun-if-changed=src/sample.cu");
9-
cc::Build::new()
10-
.cuda(true)
11-
.flag("-gencode")
12-
.flag("arch=compute_80,code=sm_80")
13-
.flag("-allow-unsupported-compiler")
14-
.file("src/sample.cu")
15-
.compile("sample");
9+
if find_nccl_root().is_some() {
10+
nccl.define();
11+
}
1612
}
1713
}

0 commit comments

Comments
 (0)