Skip to content
10 changes: 8 additions & 2 deletions backends/candle/src/layers/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,23 @@ use candle::{Device, Result, Tensor};
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
#[serde(alias = "gelu_pytorch_tanh")]
#[serde(rename = "gelu")]
GeluErf,
#[serde(alias = "gelu_new", alias = "gelu_pytorch_tanh")]
Gelu,
#[serde(alias = "relu")]
Relu,
#[serde(alias = "silu")]
Silu,
#[serde(rename = "swiglu")]
Swiglu,
}

impl HiddenAct {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
match self {
Self::GeluErf => x.gelu_erf(),
Self::Gelu => x.gelu(),
Self::Relu => x.relu(),
Self::Silu => x.silu(),
Expand Down Expand Up @@ -84,6 +89,7 @@ impl Linear {

if let Some(act) = &self.act {
match act {
HiddenAct::GeluErf => x.gelu_erf(),
HiddenAct::Gelu => x.gelu(),
HiddenAct::Relu => x.relu(),
HiddenAct::Silu => x.silu(),
Expand Down
Loading