@@ -1,8 +1,13 @@
 <?php
 
+declare(strict_types=1);
+
 namespace Rubix\ML\NeuralNet\ActivationFunctions\SELU;
 
-use Tensor\Matrix;
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
 
 /**
  * SELU
@@ -18,69 +23,94 @@
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
+ * @author Samuel Akopyan <leumas.a@gmail.com>
  */
-class SELU implements ActivationFunction
+class SELU implements ActivationFunction, IBufferDerivative
 {
     /**
      * The value at which leakage starts to saturate.
      *
      * @var float
      */
-    public const ALPHA = 1.6732632423543772848170429916717;
+    public const ALPHA = 1.6732632;
 
     /**
      * The scaling coefficient.
      *
      * @var float
      */
-    public const SCALE = 1.0507009873554804934193349852946;
+    public const LAMBDA = 1.0507009;
 
     /**
      * The scaling coefficient multiplied by alpha.
      *
      * @var float
      */
-    protected const BETA = self::SCALE * self::ALPHA;
+    protected const BETA = self::LAMBDA * self::ALPHA;
 
     /**
      * Compute the activation.
      *
-     * @internal
+     * f(x) = λ * x if x > 0
+     * f(x) = λ * α * (e^x - 1) if x ≤ 0
      *
-     * @param Matrix $input
-     * @return Matrix
+     * @param NDArray $input The input values
+     * @return NDArray The activated values
      */
-    public function activate(Matrix $input) : Matrix
+    public function activate(NDArray $input) : NDArray
     {
-        $positive = NumPower::maximum($input, 0) * self::SCALE;
-        $negative = self::BETA * NumPower::expm1($input);
+        // Calculate positive part: λ * x for x > 0
+        $positive = NumPower::multiply(
+            self::LAMBDA,
+            NumPower::maximum($input, 0)
+        );
+
+        // Calculate negative part: λ * α * (e^x - 1) for x <= 0
+        $negativeMask = NumPower::minimum($input, 0);
+        $negative = NumPower::multiply(
+            self::BETA,
+            NumPower::expm1($negativeMask)
+        );
 
-        return $negative + $positive;
+        // Combine both parts
+        return NumPower::add($positive, $negative);
     }
 
     /**
-     * Calculate the derivative of the activation.
+     * Calculate the derivative of the SELU activation function.
      *
-     * @internal
+     * f'(x) = λ if x > 0
+     * f'(x) = λ * α * e^x if x ≤ 0
      *
-     * @param Matrix $input
-     * @param Matrix $output
-     * @return Matrix
+     * @param NDArray $input Input matrix
+     * @return NDArray Derivative matrix
      */
-    public function differentiate(Matrix $input, Matrix $output) : Matrix
+    public function differentiate(NDArray $input) : NDArray
     {
-        $positive = NumPower::greater($output, 0) * self::SCALE;
-        $negative = NumPower::lessEqual($output) * ($output + self::ALPHA) * self::SCALE;
+        // For x > 0: λ
+        $positiveMask = NumPower::greater($input, 0);
+        $positivePart = NumPower::multiply($positiveMask, self::LAMBDA);
 
-        return $positive + $negative;
+        // For x <= 0: λ * α * e^x
+        $negativeMask = NumPower::lessEqual($input, 0);
+        $negativePart = NumPower::multiply(
+            NumPower::multiply(
+                NumPower::exp(
+                    NumPower::multiply($negativeMask, $input)
+                ),
+                self::BETA
+            ),
+            $negativeMask
+        );
+
+        // Combine both parts
+        return NumPower::add($positivePart, $negativePart);
     }
 
     /**
-     * Return the string representation of the object.
-     *
-     * @internal
+     * Return the string representation of the activation function.
      *
-     * @return string
+     * @return string String representation
      */
     public function __toString() : string
     {
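
A minimal usage sketch of the new activate() hunk above. This assumes the NumPower extension's NDArray::array() constructor and the NDArray::toArray() method; the printed values are rounded.

use NDArray;
use Rubix\ML\NeuralNet\ActivationFunctions\SELU\SELU;

$selu = new SELU();

// Negative inputs saturate toward -λ * α ≈ -1.758; positive inputs are scaled by λ ≈ 1.0507.
$input = NDArray::array([-3.0, -1.0, 0.0, 1.0, 3.0]);

$activations = $selu->activate($input);

print_r($activations->toArray());
// ≈ [-1.6706, -1.1113, 0.0, 1.0507, 3.1521]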
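
The differentiate() hunk follows the same mask-and-combine pattern. A spot-check sketch under the same assumptions, with a hypothetical scalar helper seluDerivative() (not part of the library) mirroring the branch logic:

$derivatives = $selu->differentiate($input);
// ≈ [0.0875, 0.6468, 1.7581, 1.0507, 1.0507]

// Hypothetical scalar reference for spot-checking the vectorized version.
function seluDerivative(float $x) : float
{
    return $x > 0.0
        ? SELU::LAMBDA
        : SELU::LAMBDA * SELU::ALPHA * exp($x);
}

Note the double multiply by $negativeMask in differentiate(): the inner multiply zeroes positive entries before exp(), so large positive inputs evaluate to e^0 = 1 instead of overflowing, and the outer multiply strips those placeholder ones back out of the negative branch. At x = 0 the mask is 1, so the derivative takes the left-hand value λ * α ≈ 1.7581 rather than λ.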