Skip to content

Commit 01bea21

Browse files
Add dummy dtypes (#3195)
* Add dummy i32/i16/f6e2m3/f6e3m2/f4/f8e8m0 dtypes * Metal fixes * Fix candle-onnx build * Apply review comments * Residual fixes * Apply review comments * Apply review comments * Revert some things * Free more space --------- Co-authored-by: ivarflakstad <[email protected]>
1 parent b801ef6 commit 01bea21

File tree

26 files changed

+1697
-227
lines changed

26 files changed

+1697
-227
lines changed

.github/workflows/rust-ci.yml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@ jobs:
3131
matrix:
3232
os: [ubuntu-latest, windows-latest, macOS-latest]
3333
steps:
34-
- name: Delete huge unnecessary tools folder
34+
- name: Free disk space (Linux)
3535
if: runner.os == 'Linux'
36-
run: rm -rf /opt/hostedtoolcache
36+
run: |
37+
sudo rm -rf /opt/hostedtoolcache
38+
sudo rm -rf /usr/share/dotnet
39+
sudo rm -rf /usr/local/lib/android
40+
sudo rm -rf /opt/ghc
41+
df -h
3742
- uses: actions/checkout@v5
3843
- uses: actions/setup-python@v6
3944
with:
@@ -48,7 +53,7 @@ jobs:
4853
- name: Run tests (with lld on Linux)
4954
if: runner.os == 'Linux'
5055
env:
51-
RUSTFLAGS: "-Clinker-features=-lld"
56+
RUSTFLAGS: "-C link-arg=-fuse-ld=lld"
5257
run: cargo test --workspace
5358
- name: Run tests (Windows & macOS)
5459
if: runner.os != 'Linux'

candle-core/src/backend.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
158158
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
159159

160160
fn set_seed(&self, _: u64) -> Result<()>;
161+
fn get_current_seed(&self) -> Result<u64>;
161162

162163
/// Synchronize should block until all the operations on the device are completed.
163164
fn synchronize(&self) -> Result<()>;

candle-core/src/convert.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! Implement conversion traits for tensors
22
use crate::{DType, Device, Error, Tensor, WithDType};
3-
use float8::F8E4M3;
43
use half::{bf16, f16, slice::HalfFloatSliceExt};
54
use std::convert::TryFrom;
65

@@ -94,6 +93,8 @@ from_tensor!(f32);
9493
from_tensor!(f16);
9594
from_tensor!(bf16);
9695
from_tensor!(i64);
96+
from_tensor!(i32);
97+
from_tensor!(i16);
9798
from_tensor!(u32);
9899
from_tensor!(u8);
99100

@@ -131,6 +132,16 @@ impl Tensor {
131132
f.write_u32::<LittleEndian>(v)?
132133
}
133134
}
135+
DType::I16 => {
136+
for v in vs.to_vec1::<i16>()? {
137+
f.write_i16::<LittleEndian>(v)?
138+
}
139+
}
140+
DType::I32 => {
141+
for v in vs.to_vec1::<i32>()? {
142+
f.write_i32::<LittleEndian>(v)?
143+
}
144+
}
134145
DType::I64 => {
135146
for v in vs.to_vec1::<i64>()? {
136147
f.write_i64::<LittleEndian>(v)?
@@ -141,10 +152,14 @@ impl Tensor {
141152
f.write_all(&vs)?;
142153
}
143154
DType::F8E4M3 => {
144-
for v in vs.to_vec1::<F8E4M3>()? {
155+
let vs = vs.to_vec1::<float8::F8E4M3>()?;
156+
for v in vs {
145157
f.write_u8(v.to_bits())?
146158
}
147159
}
160+
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
161+
return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt())
162+
}
148163
}
149164
Ok(())
150165
}

candle-core/src/cpu/kernels.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,28 @@ impl VecOps for u32 {
151151
<Self as Ord>::max(self, other)
152152
}
153153
}
154+
impl VecOps for i16 {
155+
#[inline(always)]
156+
fn min(self, other: Self) -> Self {
157+
<Self as Ord>::min(self, other)
158+
}
159+
160+
#[inline(always)]
161+
fn max(self, other: Self) -> Self {
162+
<Self as Ord>::max(self, other)
163+
}
164+
}
165+
impl VecOps for i32 {
166+
#[inline(always)]
167+
fn min(self, other: Self) -> Self {
168+
<Self as Ord>::min(self, other)
169+
}
170+
171+
#[inline(always)]
172+
fn max(self, other: Self) -> Self {
173+
<Self as Ord>::max(self, other)
174+
}
175+
}
154176
impl VecOps for i64 {
155177
#[inline(always)]
156178
fn min(self, other: Self) -> Self {
@@ -163,6 +185,18 @@ impl VecOps for i64 {
163185
}
164186
}
165187

188+
impl VecOps for float8::F8E4M3 {
189+
#[inline(always)]
190+
fn min(self, other: Self) -> Self {
191+
Self::min(self, other)
192+
}
193+
194+
#[inline(always)]
195+
fn max(self, other: Self) -> Self {
196+
Self::max(self, other)
197+
}
198+
}
199+
166200
#[inline(always)]
167201
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
168202
if n_threads == 1 {

0 commit comments

Comments
 (0)