Skip to content

Commit 323e7ff

Browse files
committed
feat(llama.cu): 根据维度选 add 算子
1 parent f96f056 commit 323e7ff

File tree

5 files changed

+13
-8
lines changed

5 files changed

+13
-8
lines changed

Cargo.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llama.cu/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ edition.workspace = true
55

66
[dependencies]
77
operators = { git = "https://github.com/CearX/operators-rs.git", rev = "8a0d58a", default-features = false, features = ["nvidia-gpu", "common-cpu"] }
8-
nn = { git = "https://github.com/CearX/InfiniNN.git", rev = "74b1bea" }
8+
nn = { git = "https://github.com/CearX/InfiniNN.git", rev = "3ba7418" }
99
ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
1010
tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }
1111

llama.cu/src/exec/group.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ impl<'ctx> Internal<'ctx> {
453453
fn builder() -> GraphBuilder {
454454
let mut ans = GraphBuilder::default();
455455
ans.register_op("embedding", op::embedding::Embedding)
456-
.register_op("add4d", op::add4d::Add4d)
456+
.register_op("add", op::add::Add)
457457
.register_op("conv", op::conv::Conv)
458458
.register_op("layer-norm", op::normalization::LayerNorm)
459459
.register_op("rms-norm", op::normalization::RmsNorm)

llama.cu/src/exec/step.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,11 @@ impl<'ctx> Handle<'ctx> {
196196
"rms-norm" => launch!(RmsNorm),
197197
"layer-norm" => launch!(LayerNorm),
198198
"linear" => launch!(Linear),
199-
"add4d" => launch!(Add4d),
199+
"add" => match inputs[0].shape().len() {
200+
2 => launch!(Add),
201+
4 => launch!(Add4d),
202+
_ => panic!("add: unsupported shape"),
203+
},
200204
"rope" => launch!(Rope),
201205
"mrope" => launch!(MRope),
202206
"gelu" => launch!(Gelu),

llama.cu/src/op/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use operators::cuda::{Stream, VirByte};
2121

2222
pub mod random_sample;
2323

24+
pub use add::Add;
2425
pub use add4d::Add4d;
2526
#[cfg(nccl)]
2627
pub use all_reduce::AllReduce;

0 commit comments

Comments
 (0)