
Commit 5a1d2c4

Merge pull request #378 from apphp/SAM-11-SiLU (SAM-11 SiLU)
2 parents 9119f26 + 512719f, commit 5a1d2c4

File tree: 8 files changed, +291 -134 lines changed


src/NeuralNet/ActivationFunctions/GELU/GELU.php

Lines changed: 17 additions & 13 deletions
@@ -33,13 +33,17 @@ class GELU implements ActivationFunction, IBufferDerivative
      * @var float
      */
     protected const ALPHA = 0.7978845608;
+    /** @var float 0.5 * ALPHA */
+    protected const HALF_ALPHA = 0.3989422804;
 
     /**
      * Gaussian error function approximation term.
      *
      * @var float
      */
     protected const BETA = 0.044715;
+    /** @var float 3 * BETA */
+    protected const TRIPLE_BETA = 0.134145;
 
     /**
      * Apply the GeLU activation function to the input.
@@ -57,21 +61,21 @@ public function activate(NDArray $input) : NDArray
         // Calculate inner term: x + BETA * x^3
         $innerTerm = NumPower::add(
             $input,
-            NumPower::multiply(self::BETA, $cubed)
+            NumPower::multiply($cubed, self::BETA)
         );
 
         // Apply tanh(ALPHA * innerTerm)
         $tanhTerm = NumPower::tanh(
-            NumPower::multiply(self::ALPHA, $innerTerm)
+            NumPower::multiply($innerTerm, self::ALPHA)
         );
 
         // Calculate 1 + tanhTerm
         $onePlusTanh = NumPower::add(1.0, $tanhTerm);
 
         // Calculate 0.5 * x * (1 + tanhTerm)
         return NumPower::multiply(
-            0.5,
-            NumPower::multiply($input, $onePlusTanh)
+            NumPower::multiply($input, $onePlusTanh),
+            0.5
         );
     }
 
@@ -97,11 +101,11 @@ public function differentiate(NDArray $input) : NDArray
 
         // Calculate inner term: ALPHA * (x + BETA * x^3)
         $innerTerm = NumPower::multiply(
-            self::ALPHA,
             NumPower::add(
                 $input,
-                NumPower::multiply(self::BETA, $cubed)
-            )
+                NumPower::multiply($cubed, self::BETA)
+            ),
+            self::ALPHA
         );
 
         // Calculate cosh and sech^2
@@ -113,24 +117,24 @@ public function differentiate(NDArray $input) : NDArray
 
         // Calculate 0.5 * (1 + tanh(innerTerm))
         $firstTerm = NumPower::multiply(
-            0.5,
-            NumPower::add(1.0, NumPower::tanh($innerTerm))
+            NumPower::add(1.0, NumPower::tanh($innerTerm)),
+            0.5
         );
 
         // Calculate 0.5 * x * sech^2 * ALPHA * (1 + 3 * BETA * x^2)
         $secondTerm = NumPower::multiply(
             NumPower::multiply(
                 NumPower::multiply(
-                    0.5 * self::ALPHA,
-                    $input
+                    $input,
+                    self::HALF_ALPHA
                 ),
                 $sech2
             ),
             NumPower::add(
                 1.0,
                 NumPower::multiply(
-                    3.0 * self::BETA,
-                    NumPower::pow($input, 2)
+                    NumPower::pow($input, 2),
+                    self::TRIPLE_BETA
                 )
             )
         );
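
For reference, these constants serve the tanh approximation GELU(x) ≈ 0.5 · x · (1 + tanh(0.7978845608 · (x + 0.044715 · x³))); HALF_ALPHA and TRIPLE_BETA simply hold 0.5 · ALPHA and 3 · BETA as named constants so the derivative no longer computes 0.5 * self::ALPHA and 3.0 * self::BETA inline. A minimal usage sketch, assuming the NumPower extension and the Rubix\ML autoloader are available (the sample values are illustrative and not taken from this diff):

    use Rubix\ML\NeuralNet\ActivationFunctions\GELU\GELU;

    $gelu = new GELU();

    // A small batch of pre-activations.
    $input = NumPower::array([[-2.0, -0.5, 0.0, 0.5, 2.0]]);

    // Forward pass: 0.5 * x * (1 + tanh(ALPHA * (x + BETA * x^3))).
    $activations = $gelu->activate($input);

    // Backward pass; this is where HALF_ALPHA and TRIPLE_BETA are used.
    $gradients = $gelu->differentiate($input);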

src/NeuralNet/ActivationFunctions/HardSigmoid/HardSigmoid.php

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ public function activate(NDArray $input) : NDArray
     {
         // Calculate 0.2 * x + 0.5
         $linear = NumPower::add(
-            NumPower::multiply(self::SLOPE, $input),
+            NumPower::multiply($input, self::SLOPE),
             self::INTERCEPT
         );
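
The swap is purely an argument-order change: with SLOPE = 0.2 and INTERCEPT = 0.5, an input of 1.0 still maps to 0.2 * 1.0 + 0.5 = 0.7 before any clamping applied elsewhere in the class (not shown in this hunk). A scalar sketch of just this linear step, in plain PHP with illustrative names:

    // Illustrative scalar version of the linear step 0.2 * x + 0.5.
    function hardSigmoidLinear(float $x, float $slope = 0.2, float $intercept = 0.5) : float
    {
        // Multiplication is commutative, so swapping the operands changes
        // only the NumPower call signature, never the result.
        return $x * $slope + $intercept;
    }

    // hardSigmoidLinear(1.0) === 0.7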

src/NeuralNet/ActivationFunctions/SELU/SELU.php

Lines changed: 4 additions & 4 deletions
@@ -61,15 +61,15 @@ public function activate(NDArray $input) : NDArray
     {
         // Calculate positive part: λ * x for x > 0
         $positive = NumPower::multiply(
-            self::LAMBDA,
-            NumPower::maximum($input, 0)
+            NumPower::maximum($input, 0),
+            self::LAMBDA
         );
 
         // Calculate negative part: λ * α * (e^x - 1) for x <= 0
         $negativeMask = NumPower::minimum($input, 0);
         $negative = NumPower::multiply(
-            self::BETA,
-            NumPower::expm1($negativeMask)
+            NumPower::expm1($negativeMask),
+            self::BETA
         );
 
         // Combine both parts
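
The expected values in SELUTest further down pin the constants: 2.0 maps to 2.10140180, so LAMBDA ≈ 1.0507009, and -10.0 maps to -1.7580193, so BETA = LAMBDA * ALPHA ≈ 1.7580993. A scalar sketch of the same piecewise rule, handy for checking the provider rows by hand:

    // Illustrative scalar SELU using the constants implied by SELUTest.
    function selu(float $x) : float
    {
        $lambda = 1.0507009;    // scale λ
        $beta = 1.7580993;      // λ * α

        return $x > 0.0 ? $lambda * $x : $beta * (exp($x) - 1.0);
    }

    // selu(2.0)  ≈  2.1014018
    // selu(-0.5) ≈ -0.6917580
    // selu(-10)  ≈ -1.7580193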

src/NeuralNet/ActivationFunctions/SiLU/SiLU.php

Lines changed: 49 additions & 13 deletions
@@ -1,8 +1,14 @@
 <?php
 
+declare(strict_types=1);
+
 namespace Rubix\ML\NeuralNet\ActivationFunctions\SiLU;
 
-use Tensor\Matrix;
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
+use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid\Sigmoid;
 
 /**
  * SiLU
@@ -17,34 +23,64 @@
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
+ * @author Samuel Akopyan <leumas.a@gmail.com>
  */
-class SiLU implements ActivationFunction
+class SiLU implements ActivationFunction, IBufferDerivative
 {
+    /**
+     * The Sigmoid activation function.
+     *
+     * @var Sigmoid
+     */
+    protected Sigmoid $sigmoid;
+
+    /**
+     * Class constructor.
+     */
+    public function __construct()
+    {
+        $this->sigmoid = new Sigmoid();
+    }
+
     /**
      * Compute the activation.
      *
-     * @internal
+     * f(x) = x * sigmoid(x) = x / (1 + e^(-x))
      *
-     * @param Matrix $input
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function activate(Matrix $input) : Matrix
+    public function activate(NDArray $input) : NDArray
     {
-        return $input / (1.0 + NumPower::exp(-$input));
+        // Calculate sigmoid(x) using the Sigmoid activation function
+        $sigmoid = $this->sigmoid->activate($input);
+
+        // Calculate x * sigmoid(x)
+        return NumPower::multiply($input, $sigmoid);
     }
 
     /**
      * Calculate the derivative of the activation.
      *
-     * @internal
+     * f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
+     *       = sigmoid(x) + x * sigmoid'(x)
      *
-     * @param Matrix $input
-     * @param Matrix $output
-     * @return Matrix
+     * @param NDArray $input Input matrix
+     * @return NDArray Derivative matrix
      */
-    public function differentiate(Matrix $input, Matrix $output) : Matrix
+    public function differentiate(NDArray $input) : NDArray
     {
-        return $output / $input * NumPower::ones($output->shape()) / $output * 2;
+        // Calculate sigmoid(x) using the Sigmoid activation function
+        $sigmoid = $this->sigmoid->activate($input);
+
+        // Calculate sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
+        $sigmoidDerivative = $this->sigmoid->differentiate($sigmoid);
+
+        // Calculate x * sigmoid'(x)
+        $xTimesSigmoidDerivative = NumPower::multiply($input, $sigmoidDerivative);
+
+        // Calculate sigmoid(x) + x * sigmoid'(x)
+        return NumPower::add($sigmoid, $xTimesSigmoidDerivative);
     }
 
     /**
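
A minimal usage sketch of the rewritten class, assuming the NumPower extension is loaded (the expected values below are worked out from f(x) = x · sigmoid(x) and are not part of the test suite):

    use Rubix\ML\NeuralNet\ActivationFunctions\SiLU\SiLU;

    $silu = new SiLU();

    $input = NumPower::array([[-2.0, 0.0, 2.0]]);

    // f(x) = x * sigmoid(x): approximately [-0.2384, 0.0, 1.7616]
    $activations = $silu->activate($input);

    // f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)):
    // approximately [-0.0908, 0.5, 1.0908]
    $gradients = $silu->differentiate($input);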

src/NeuralNet/ActivationFunctions/Sigmoid/Sigmoid.php

Lines changed: 4 additions & 4 deletions
@@ -34,11 +34,11 @@ class Sigmoid implements ActivationFunction, OBufferDerivative
     public function activate(NDArray $input) : NDArray
     {
         // Calculate e^(-x)
-        $negExp = NumPower::exp(NumPower::multiply(-1.0, $input));
-
+        $negExp = NumPower::exp(NumPower::multiply($input, -1.0));
+
         // Calculate 1 + e^(-x)
         $denominator = NumPower::add(1.0, $negExp);
-
+
         // Calculate 1 / (1 + e^(-x))
         return NumPower::divide(1.0, $denominator);
     }
@@ -57,7 +57,7 @@ public function differentiate(NDArray $output) : NDArray
     {
         // Calculate (1 - output)
         $oneMinusOutput = NumPower::subtract(1.0, $output);
-
+
         // Calculate output * (1 - output)
         return NumPower::multiply($output, $oneMinusOutput);
     }
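
In scalar form the two methods are simply σ(x) = 1 / (1 + e^(-x)) and σ'(x) = σ(x) · (1 - σ(x)), which is why differentiate() receives the output buffer rather than the input: the derivative can be recovered from the activation alone. A plain-PHP sketch of the same arithmetic:

    // Scalar equivalents of Sigmoid::activate() and Sigmoid::differentiate().
    function sigmoid(float $x) : float
    {
        return 1.0 / (1.0 + exp(-$x));
    }

    function sigmoidDerivativeFromOutput(float $output) : float
    {
        return $output * (1.0 - $output);
    }

    // sigmoid(0.0) = 0.5, and the derivative at that output is 0.5 * 0.5 = 0.25.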

tests/NeuralNet/ActivationFunctions/SELU/SELUTest.php

Lines changed: 15 additions & 50 deletions
@@ -19,11 +19,6 @@
 #[CoversClass(SELU::class)]
 class SELUTest extends TestCase
 {
-    /**
-     * @var SELU
-     */
-    protected SELU $activationFn;
-
     /**
      * The value at which leakage starts to saturate.
      *
@@ -45,6 +40,11 @@ class SELUTest extends TestCase
      */
    protected const BETA = self::LAMBDA * self::ALPHA;
 
+    /**
+     * @var SELU
+     */
+    protected SELU $activationFn;
+
     /**
      * @return Generator<array>
      */
@@ -55,14 +55,7 @@ public static function computeProvider() : Generator
                 [2.0, 1.0, -0.5, 0.0, 20.0, -10.0],
             ]),
             [
-                [
-                    2.10140180,
-                    1.05070090,
-                    -0.6917580,
-                    0.0,
-                    21.0140190,
-                    -1.7580193
-                ],
+                [2.10140180, 1.05070090, -0.6917580, 0.0, 21.0140190, -1.7580193],
             ],
         ];
 
@@ -76,17 +69,17 @@ public static function computeProvider() : Generator
                 [
                     self::BETA * (exp(-0.12) - 1.0),
                     0.31 * self::LAMBDA,
-                    self::BETA * (exp(-0.49) - 1.0)
+                    self::BETA * (exp(-0.49) - 1.0),
                 ],
                 [
                     0.99 * self::LAMBDA,
                     0.08 * self::LAMBDA,
-                    self::BETA * (exp(-0.03) - 1.0)
+                    self::BETA * (exp(-0.03) - 1.0),
                 ],
                 [
                     0.05 * self::LAMBDA,
                     self::BETA * (exp(-0.52) - 1.0),
-                    0.54 * self::LAMBDA
+                    0.54 * self::LAMBDA,
                 ],
             ],
         ];
@@ -102,15 +95,7 @@ public static function differentiateProvider() : Generator
                 [2.0, 1.0, -0.5, 0.0, 20.0, -10.0, -20],
             ]),
             [
-                [
-                    self::LAMBDA,
-                    self::LAMBDA,
-                    1.0663410,
-                    1.7580991,
-                    self::LAMBDA,
-                    0.0000798,
-                    0.0
-                ],
+                [self::LAMBDA, self::LAMBDA, 1.0663410, 1.7580991, self::LAMBDA, 0.0000798, 0.0],
             ],
         ];
 
@@ -121,21 +106,9 @@ public static function differentiateProvider() : Generator
                 [0.05, -0.52, 0.54],
             ]),
             [
-                [
-                    self::BETA * exp(-0.12),
-                    self::LAMBDA,
-                    self::BETA * exp(-0.49)
-                ],
-                [
-                    self::LAMBDA,
-                    self::LAMBDA,
-                    self::BETA * exp(-0.03)
-                ],
-                [
-                    self::LAMBDA,
-                    self::BETA * exp(-0.52),
-                    self::LAMBDA
-                ],
+                [self::BETA * exp(-0.12), self::LAMBDA, self::BETA * exp(-0.49)],
+                [self::LAMBDA, self::LAMBDA, self::BETA * exp(-0.03)],
+                [self::LAMBDA, self::BETA * exp(-0.52), self::LAMBDA],
             ],
         ];
     }
@@ -163,18 +136,10 @@ public static function zeroRegionProvider() : Generator
         yield [
             NumPower::array([[-1e-15, -1e-10, -1e-7]]),
             [
-                [
-                    self::BETA * (exp(-1e-15) - 1.0),
-                    self::BETA * (exp(-1e-10) - 1.0),
-                    self::BETA * (exp(-1e-7) - 1.0),
-                ],
+                [self::BETA * (exp(-1e-15) - 1.0), self::BETA * (exp(-1e-10) - 1.0), self::BETA * (exp(-1e-7) - 1.0)],
             ],
             [
-                [
-                    self::BETA * exp(-1e-15),
-                    self::BETA * exp(-1e-10),
-                    self::BETA * exp(-1e-7),
-                ],
+                [self::BETA * exp(-1e-15), self::BETA * exp(-1e-10), self::BETA * exp(-1e-7)],
             ],
         ];
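
For context, a provider-driven test of this shape typically looks like the sketch below; the method name, the delta, and the toArray() call are assumptions for illustration, not part of this diff:

    // Inside the SELUTest class body (hypothetical consumer of computeProvider()).
    #[\PHPUnit\Framework\Attributes\DataProvider('computeProvider')]
    public function testActivate(NDArray $input, array $expected) : void
    {
        $activations = $this->activationFn->activate($input);

        static::assertEqualsWithDelta($expected, $activations->toArray(), 1e-7);
    }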
