Skip to content

Commit f5833fb

Browse files
committed
fix(infini): 为统一算子库修改
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 4afc911 commit f5833fb

File tree

4 files changed

+20
-16
lines changed

4 files changed

+20
-16
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ regex = "1.11"
3030
env_logger = "0.11"
3131
build-script-cfg = "0.0"
3232

33-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "360b664", default-features = false }
33+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "d22d2bc", default-features = false }
3434
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
3535
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "0e57976" }
3636
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "5b9dbd9" }

models/llama/infini/src/infer.rs

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use gguf::{ggml_quants::digit_layout::types, GGufModel};
33
use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
44
use operators::{
55
infini::Device,
6-
infini_rt::{self, DeviceType::DEVICE_CPU},
6+
infini_rt,
77
random_sample::{KVPair, SampleArgs},
88
TopoNode,
99
};
@@ -59,7 +59,7 @@ fn test_infer() {
5959
let count = devices.len();
6060
println!("distribution: {devices:?}");
6161

62-
infini_rt::init(DEVICE_CPU);
62+
infini_rt::init(infini_rt::DEVICE_CPU);
6363
let (seeds, senders) = WorkerSeed::new(devices.into_iter().map(|_| Device::cpu()).collect());
6464
thread::scope(|s| {
6565
let _workers = zip(lens, seeds)
@@ -90,8 +90,6 @@ fn test_infer() {
9090

9191
let sample = RandomSample::new(&node);
9292
let indices = RandomSample::build_indices(model.meta.nvoc, &stream);
93-
let mut pair = KVPair::new(0, f16::ZERO);
94-
let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::<u8>(size));
9593

9694
for task in tasks {
9795
let Task {
@@ -127,6 +125,11 @@ fn test_infer() {
127125
)
128126
.unwrap();
129127
if id == 0 {
128+
// NOTICE 目前 random sample 完全是 CPU 上执行的,没必要再拷贝了
129+
let mut pair = KVPair::new(0, f16::ZERO);
130+
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
131+
from_raw_parts_mut(&mut pair as *mut _ as _, size_of_val(&pair))
132+
});
130133
sample
131134
.launch(
132135
&mut pairs,
@@ -137,18 +140,7 @@ fn test_infer() {
137140
&stream,
138141
)
139142
.unwrap();
140-
141143
stream.synchronize();
142-
device.memcpy_d2h(
143-
unsafe {
144-
from_raw_parts_mut(
145-
&mut pair as *mut _ as *mut u8,
146-
pairs.get().len(),
147-
)
148-
},
149-
pairs.get(),
150-
);
151-
152144
next.send(pair.idx() as _).unwrap()
153145
}
154146
}

models/llama/infini/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ where
5353
let tensor = tensor.as_ref().map(|s| {
5454
let mut host = vec![0u8; s.len()];
5555
queue.get_device().memcpy_d2h(&mut host, s);
56+
queue.synchronize();
5657
host
5758
});
5859
println!("{tensor}");

tensor/src/fmt.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ impl DataFmt for u32 {
5656
}
5757
}
5858

59+
impl DataFmt for u64 {
60+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61+
if *self == 0 {
62+
write!(f, " ________")
63+
} else {
64+
write!(f, "{self:>6}")
65+
}
66+
}
67+
}
68+
5969
impl<Physical: Deref<Target = [u8]>> fmt::Display for Tensor<Physical> {
6070
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
6171
match self.dt {
@@ -64,6 +74,7 @@ impl<Physical: Deref<Target = [u8]>> fmt::Display for Tensor<Physical> {
6474
primitive::F32 => self.map_slice().write_tensor::<f32>(&mut vec![], f),
6575
primitive::F64 => self.map_slice().write_tensor::<f64>(&mut vec![], f),
6676
primitive::U32 => self.map_slice().write_tensor::<u32>(&mut vec![], f),
77+
primitive::U64 => self.map_slice().write_tensor::<u64>(&mut vec![], f),
6778
_ => todo!(),
6879
}
6980
}

0 commit comments

Comments
 (0)