diff --git a/Cargo.toml b/Cargo.toml index e16f949db6..a837d87a1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,10 +62,7 @@ half = { version = "2.5.0", features = [ "use-intrinsics", "rand_distr", ] } -float8 = { git = "https://github.com/zackangelo/float8", branch = "cudarc_0_16", features = [ - "num-traits", - "rand_distr", -] } +float8 = { version = "0.3", features = ["num-traits", "rand_distr"] } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = [ "jpeg", @@ -82,7 +79,7 @@ parquet = { version = "51.0.0" } rand = "0.9.0" rand_distr = "0.5.1" rayon = "1.7.0" -safetensors = "0.4.1" +safetensors = "0.6" serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 7cd9d2e973..a238cc62d6 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -94,7 +94,7 @@ impl st::View for &Tensor { impl Tensor { pub fn save_safetensors>(&self, name: &str, filename: P) -> Result<()> { let data = [(name, self.clone())]; - Ok(st::serialize_to_file(data, &None, filename.as_ref())?) + Ok(st::serialize_to_file(data, None, filename.as_ref())?) } } @@ -232,6 +232,7 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { st::Dtype::F16 => convert_::(view, device), st::Dtype::F32 => convert_::(view, device), st::Dtype::F64 => convert_::(view, device), + st::Dtype::F8_E4M3 => convert_::(view, device), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -268,7 +269,7 @@ pub fn save + Ord + std::fmt::Display, P: AsRef>( tensors: &HashMap, filename: P, ) -> Result<()> { - Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) + Ok(st::serialize_to_file(tensors, None, filename.as_ref())?) } #[derive(yoke::Yokeable)]