diff --git a/backends/candle/src/layers/cublaslt.rs b/backends/candle/src/layers/cublaslt.rs index b1d2b7391..8cc83b238 100644 --- a/backends/candle/src/layers/cublaslt.rs +++ b/backends/candle/src/layers/cublaslt.rs @@ -59,6 +59,7 @@ impl CublasLtWrapper { #[cfg(feature = "cuda")] { let inner_act = match act { + Some(HiddenAct::GeluErf) => Some(Activation::GeluErf), Some(HiddenAct::Gelu) => Some(Activation::Gelu), Some(HiddenAct::Relu) => Some(Activation::Relu), _ => None, @@ -100,6 +101,7 @@ impl CublasLtWrapper { #[cfg(feature = "cuda")] { let inner_act = match act { + Some(HiddenAct::GeluErf) => Some(Activation::GeluErf), Some(HiddenAct::Gelu) => Some(Activation::Gelu), Some(HiddenAct::Relu) => Some(Activation::Relu), _ => None, diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs index e15ca8e87..01c5de165 100644 --- a/backends/candle/src/layers/linear.rs +++ b/backends/candle/src/layers/linear.rs @@ -3,18 +3,23 @@ use candle::{Device, Result, Tensor}; use serde::Deserialize; #[derive(Debug, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] pub enum HiddenAct { - #[serde(alias = "gelu_pytorch_tanh")] + #[serde(rename = "gelu")] + GeluErf, + #[serde(alias = "gelu_new", alias = "gelu_pytorch_tanh")] Gelu, + #[serde(alias = "relu")] Relu, + #[serde(alias = "silu")] Silu, + #[serde(rename = "swiglu")] Swiglu, } impl HiddenAct { pub fn forward(&self, x: &Tensor) -> Result { match self { + Self::GeluErf => x.gelu_erf(), Self::Gelu => x.gelu(), Self::Relu => x.relu(), Self::Silu => x.silu(), @@ -84,6 +89,7 @@ impl Linear { if let Some(act) = &self.act { match act { + HiddenAct::GeluErf => x.gelu_erf(), HiddenAct::Gelu => x.gelu(), HiddenAct::Relu => x.relu(), HiddenAct::Silu => x.silu(), diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index 1720ce9d1..05b284ad0 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -525,7 +525,7 @@ impl BertSpladeHead { let transform = Linear::new( transform_weight, Some(transform_bias), - Some(HiddenAct::Gelu), + Some(HiddenAct::GeluErf), ); let transform_layer_norm = LayerNorm::load( diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 8748db38a..60d265e58 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -404,7 +404,7 @@ impl NomicMLP { if use_moe { Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) - } else if config.activation_function == HiddenAct::Gelu { + } else if config.activation_function == HiddenAct::GeluErf { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { Ok(Self::GatedMLP(NomicBertGatedMLP::load(vb, config)?)) diff --git a/backends/candle/tests/snapshots/test_bert__emotions_batch.snap b/backends/candle/tests/snapshots/test_bert__emotions_batch.snap index fd582b8c2..bb49b7373 100644 --- a/backends/candle/tests/snapshots/test_bert__emotions_batch.snap +++ b/backends/candle/tests/snapshots/test_bert__emotions_batch.snap @@ -2,87 +2,87 @@ source: backends/candle/tests/test_bert.rs expression: predictions_batch --- -- - -6.548559 - - -6.302024 - - -4.8671727 - - -3.9600255 - - -4.6329865 - - -6.2816987 - - -6.069644 - - -5.7742686 - - -6.9259467 - - -6.1909447 - - -5.67395 - - -6.1698227 - - -7.513461 - - -6.865867 - - -7.186479 - - -7.128109 - - -8.210709 - - -7.0171394 - - -7.1321163 - - -8.533409 - - -6.2294865 - - -8.742306 - - -5.7792044 - - -8.657227 - - -8.258305 - - -6.64832 - - -7.4060283 - - 3.046496 -- - -5.8167515 - - -6.6119466 - - -5.2771955 - - -2.6306503 - - -4.6419163 - - -5.579778 - - -5.797174 - - -6.0305815 - - -5.8720746 - - 0.45377323 - - -3.0235887 - - -5.3944407 - - -5.186683 - - -6.2649117 - - -6.1962767 - - -6.97937 - - -5.5674877 - - -5.521044 - - -5.8899207 - - -4.8699703 - - -5.6259933 - - -7.6109924 - - -4.3881936 - - -6.039008 - - -4.934696 - - -0.6715916 - - -6.399376 - - -2.4499295 -- - -6.548559 - - -6.302024 - - -4.8671727 - - -3.9600255 - - -4.6329865 - - -6.2816987 - - -6.069644 - - -5.7742686 - - -6.9259467 - - -6.1909447 - - -5.67395 - - -6.1698227 - - -7.513461 - - -6.865867 - - -7.186479 - - -7.128109 - - -8.210709 - - -7.0171394 - - -7.1321163 - - -8.533409 - - -6.2294865 - - -8.742306 - - -5.7792044 - - -8.657227 - - -8.258305 - - -6.64832 - - -7.4060283 - - 3.046496 +- - -6.5492845 + - -6.2986374 + - -4.869306 + - -3.9607537 + - -4.635481 + - -6.2853007 + - -6.071209 + - -5.7783217 + - -6.926041 + - -6.1886187 + - -5.6718416 + - -6.1677446 + - -7.511295 + - -6.8649445 + - -7.185084 + - -7.1265144 + - -8.2086735 + - -7.016038 + - -7.1330824 + - -8.532023 + - -6.231272 + - -8.741036 + - -5.777328 + - -8.655965 + - -8.257617 + - -6.6452403 + - -7.4038887 + - 3.0479355 +- - -5.81419 + - -6.609311 + - -5.280827 + - -2.633262 + - -4.64155 + - -5.5816803 + - -5.8019 + - -6.0343924 + - -5.8742776 + - 0.45036578 + - -3.0299604 + - -5.3952494 + - -5.1889625 + - -6.26653 + - -6.2017837 + - -6.978322 + - -5.5669065 + - -5.5190973 + - -5.8894176 + - -4.873868 + - -5.6295815 + - -7.6102467 + - -4.3884554 + - -6.039061 + - -4.936446 + - -0.67068166 + - -6.4012537 + - -2.447851 +- - -6.5492845 + - -6.2986374 + - -4.869306 + - -3.9607537 + - -4.635481 + - -6.2853007 + - -6.071209 + - -5.7783217 + - -6.926041 + - -6.1886187 + - -5.6718416 + - -6.1677446 + - -7.511295 + - -6.8649445 + - -7.185084 + - -7.1265144 + - -8.2086735 + - -7.016038 + - -7.1330824 + - -8.532023 + - -6.231272 + - -8.741036 + - -5.777328 + - -8.655965 + - -8.257617 + - -6.6452403 + - -7.4038887 + - 3.0479355