|
| 1 | +use candle_core::{Error, Tensor}; |
| 2 | + |
| 3 | +use crate::math::tensors::Matrix; |
| 4 | + |
/// The activation functions supported by this crate.
///
/// All variants are applied element-wise except [`Activations::Softmax`],
/// which normalizes along the last dimension (rows of a rank-2 `Matrix`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activations {
    /// Heaviside step: 1 where x >= 0, else 0.
    Step,
    /// Identity: passes the input through unchanged.
    Linear,
    /// Logistic sigmoid: 1 / (1 + exp(-x)).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Rectified linear unit: max(0, x).
    ReLU,
    /// Row-wise softmax: exp(x_i) / sum_j exp(x_j).
    Softmax,
    /// Swish (SiLU): x * sigmoid(x).
    Swish,
    /// Gaussian error linear unit: 0.5 * x * (1 + erf(x / sqrt(2))).
    GeLU,
    /// Cardinal sine: sin(x) / x, with sinc(0) = 1.
    Sinc,
    /// Scaled exponential linear unit (self-normalizing networks).
    SeLU,
}
| 17 | + |
| 18 | +impl Activations { |
| 19 | + pub fn activate(&self, input: Matrix) -> Result<Matrix, Error> { |
| 20 | + let tensor = input.inner(); |
| 21 | + let device = input.device.clone(); |
| 22 | + let dtype = input.dtype; |
| 23 | + |
| 24 | + let activated = match self { |
| 25 | + Activations::Step => { |
| 26 | + // Heaviside step function: 1.0 where x >= 0.0, else 0.0 |
| 27 | + let zeros = Tensor::zeros_like(tensor)?; |
| 28 | + let ones = Tensor::ones_like(tensor)?; |
| 29 | + tensor.ge(&zeros)?.where_cond(&ones, &zeros)? |
| 30 | + } |
| 31 | + Activations::Linear => tensor.clone(), |
| 32 | + Activations::Tanh => tensor.tanh()?, |
| 33 | + Activations::ReLU => tensor.relu()?, |
| 34 | + Activations::Sinc => { |
| 35 | + // sinc(x) = sin(x) / x, define 1 at x=0 |
| 36 | + // Using a small epsilon to handle division by zero. |
| 37 | + // If x is near zero, output 1, else sin(x)/x |
| 38 | + let eps_val = 1e-7f64; |
| 39 | + let eps = Tensor::full(eps_val, tensor.dims(), &device)?.to_dtype(dtype)?; |
| 40 | + let near_zero = tensor.abs()?.le(&eps)?; |
| 41 | + |
| 42 | + let numerator = tensor.sin()?; |
| 43 | + let denominator = tensor.clone(); // Clone to avoid consuming tensor |
| 44 | + let value = numerator.div(&denominator)?; |
| 45 | + |
| 46 | + near_zero.where_cond(&Tensor::ones_like(tensor)?, &value)? |
| 47 | + } |
| 48 | + Activations::Sigmoid => { |
| 49 | + // Sigmoid(x) = 1 / (1 + exp(-x)) |
| 50 | + let neg_x = tensor.neg()?; |
| 51 | + let exp_neg_x = neg_x.exp()?; |
| 52 | + let one = Tensor::ones_like(&exp_neg_x)?; |
| 53 | + let one_plus_exp_neg_x = one.add(&exp_neg_x)?; |
| 54 | + one_plus_exp_neg_x.recip()? // 1 / (1 + exp(-x)) |
| 55 | + } |
| 56 | + Activations::Softmax => { |
| 57 | + // Softmax(x_i) = exp(x_i) / sum(exp(x_j)) along the last dimension |
| 58 | + // For a Matrix (rank 2), apply along dim 1 (columns) for each row. |
| 59 | + let exp_x = tensor.exp()?; |
| 60 | + // Sum along the last dimension, keeping the dimension for broadcasting |
| 61 | + let sum_exp_x = exp_x.sum_keepdim(1)?; |
| 62 | + exp_x.broadcast_div(&sum_exp_x)? |
| 63 | + } |
| 64 | + Activations::Swish => { |
| 65 | + // Swish(x) = x * Sigmoid(x) |
| 66 | + let neg_x = tensor.neg()?; |
| 67 | + let exp_neg_x = neg_x.exp()?; |
| 68 | + let one = Tensor::ones_like(&exp_neg_x)?; |
| 69 | + let one_plus_exp_neg_x = one.add(&exp_neg_x)?; |
| 70 | + let sigmoid_x = one_plus_exp_neg_x.recip()?; |
| 71 | + tensor.mul(&sigmoid_x)? |
| 72 | + } |
| 73 | + Activations::GeLU => { |
| 74 | + // GeLU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) |
| 75 | + let sqrt_two_val = 2.0f64.sqrt(); |
| 76 | + let sqrt_two = |
| 77 | + Tensor::full(sqrt_two_val, tensor.dims(), &device)?.to_dtype(dtype)?; |
| 78 | + |
| 79 | + let x_div_sqrt_two = tensor.div(&sqrt_two)?; |
| 80 | + let erf_val = x_div_sqrt_two.erf()?; |
| 81 | + let one = Tensor::ones_like(&erf_val)?; |
| 82 | + let one_plus_erf = one.add(&erf_val)?; |
| 83 | + |
| 84 | + let half_val = 0.5f64; |
| 85 | + let half = Tensor::full(half_val, tensor.dims(), &device)?.to_dtype(dtype)?; |
| 86 | + |
| 87 | + tensor.mul(&half)?.mul(&one_plus_erf)? |
| 88 | + } |
| 89 | + Activations::SeLU => { |
| 90 | + // SeLU(x) = lambda * (x if x > 0 else alpha * (exp(x) - 1)) |
| 91 | + // Standard constants for SeLU |
| 92 | + let alpha_val = 1.673_263_242_354_377_2_f64; |
| 93 | + let lambda_val = 1.050_700_987_355_480_5_f64; |
| 94 | + |
| 95 | + let alpha = Tensor::full(alpha_val, tensor.dims(), &device)?.to_dtype(dtype)?; |
| 96 | + let lambda = Tensor::full(lambda_val, tensor.dims(), &device)?.to_dtype(dtype)?; |
| 97 | + let zero = Tensor::zeros_like(tensor)?; |
| 98 | + |
| 99 | + // Condition: x > 0 |
| 100 | + let cond_gt_zero = tensor.gt(&zero)?; |
| 101 | + |
| 102 | + // Case for x > 0: just x |
| 103 | + let case_gt_zero = tensor.clone(); |
| 104 | + |
| 105 | + // Case for x <= 0: alpha * (exp(x) - 1) |
| 106 | + let exp_x = tensor.exp()?; |
| 107 | + let one_for_sub = Tensor::ones_like(&exp_x)?; |
| 108 | + let exp_x_minus_one = exp_x.sub(&one_for_sub)?; |
| 109 | + let case_le_zero = alpha.mul(&exp_x_minus_one)?; |
| 110 | + |
| 111 | + let result = cond_gt_zero.where_cond(&case_gt_zero, &case_le_zero)?; |
| 112 | + lambda.mul(&result)? |
| 113 | + } |
| 114 | + }; |
| 115 | + |
| 116 | + Matrix::new(activated, device, dtype) |
| 117 | + } |
| 118 | +} |
0 commit comments