Commit 9e4a6b2

Merge pull request #380 from apphp/SAM-13-softsign

Sam 13 softsign

2 parents 8992e18 + 1cbad5c · commit 9e4a6b2

6 files changed: +515 additions, −46 deletions

docs/neural-network/activation-functions/thresholded-relu.md

1 addition, 1 deletion

````diff
@@ -22,7 +22,7 @@ Thresholded ReLU maintains the computational efficiency of standard ReLU while a
 ## Plots
 <img src="../../images/activation-functions/thresholded-relu.png" alt="Thresholded ReLU Function" width="500" height="auto">
 
-<img src="../../images/activation-functions/thresholded-derivative.png" alt="Thresholded ReLU Derivative" width="500" height="auto">
+<img src="../../images/activation-functions/thresholded-relu-derivative.png" alt="Thresholded ReLU Derivative" width="500" height="auto">
 
 ## Example
 ```php
````

src/NeuralNet/ActivationFunctions/Softsign/Softsign.php

35 additions, 26 deletions

```diff
@@ -1,8 +1,13 @@
 <?php
 
-namespace Rubix\ML\NeuralNet\ActivationFunctions;
+declare(strict_types=1);
 
-use Tensor\Matrix;
+namespace Rubix\ML\NeuralNet\ActivationFunctions\Softsign;
+
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
 
 /**
  * Softsign
@@ -17,52 +22,56 @@
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
+ * @author Samuel Akopyan <leumas.a@gmail.com>
  */
-class Softsign implements ActivationFunction
+class Softsign implements ActivationFunction, IBufferDerivative
 {
     /**
      * Compute the activation.
      *
-     * @internal
+     * f(x) = x / (1 + |x|)
      *
-     * @param Matrix $input
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function activate(Matrix $input) : Matrix
+    public function activate(NDArray $input) : NDArray
     {
-        return $input / (1 + NumPower::abs($input));
+        // Calculate |x|
+        $absInput = NumPower::abs($input);
+
+        // Calculate 1 + |x|
+        $denominator = NumPower::add(1.0, $absInput);
+
+        // Calculate x / (1 + |x|)
+        return NumPower::divide($input, $denominator);
     }
 
     /**
      * Calculate the derivative of the activation.
      *
-     * @internal
+     * f'(x) = 1 / (1 + |x|)²
      *
-     * @param Matrix $input
-     * @param Matrix $output
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function differentiate(Matrix $input, Matrix $output) : Matrix
+    public function differentiate(NDArray $input) : NDArray
     {
-        return $input->map([$this, '_differentiate']);
-    }
+        // Calculate |x|
+        $absInput = NumPower::abs($input);
 
-    /**
-     * @internal
-     *
-     * @param float $input
-     * @return float
-     */
-    public function _differentiate(float $input) : float
-    {
-        return 1 / (1 + NumPower::abs($input)) ** 2;
+        // Calculate 1 + |x|
+        $onePlusAbs = NumPower::add(1.0, $absInput);
+
+        // Calculate (1 + |x|)²
+        $denominator = NumPower::multiply($onePlusAbs, $onePlusAbs);
+
+        // Calculate 1 / (1 + |x|)²
+        return NumPower::divide(1.0, $denominator);
     }
 
     /**
      * Return the string representation of the object.
      *
-     * @internal
-     *
      * @return string
      */
     public function __toString() : string
```
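
For context, here is a minimal usage sketch of the new buffered Softsign API. The `NumPower::array()` constructor and the sample batch are assumptions for illustration, not part of this diff:

```php
<?php

declare(strict_types=1);

use Rubix\ML\NeuralNet\ActivationFunctions\Softsign\Softsign;

// Hypothetical input batch; NumPower::array() is assumed to construct
// an NDArray, per the NumPower extension this branch targets.
$input = NumPower::array([-2.0, -1.0, 0.0, 1.0, 2.0]);

$activationFn = new Softsign();

// f(x) = x / (1 + |x|) squashes values into (-1, 1):
// [-0.6667, -0.5, 0.0, 0.5, 0.6667]
$activations = $activationFn->activate($input);

// f'(x) = 1 / (1 + |x|)² peaks at 1 for x = 0 and decays as |x| grows:
// [0.1111, 0.25, 1.0, 0.25, 0.1111]
$gradient = $activationFn->differentiate($input);
```

Note the single-argument `differentiate()` required by the `IBufferDerivative` contract: the derivative is recomputed from the buffered input alone, rather than taking a precomputed output matrix as the old two-argument signature did.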
src/NeuralNet/ActivationFunctions/ThresholdedReLU/Exceptions/InvalidThresholdException.php

19 additions, 0 deletions

```diff
@@ -0,0 +1,19 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions;
+
+use InvalidArgumentException;
+
+/**
+ * Invalid Threshold Exception
+ *
+ * @category Machine Learning
+ * @package Rubix/ML
+ * @author Samuel Akopyan <leumas.a@gmail.com>
+ */
+class InvalidThresholdException extends InvalidArgumentException
+{
+    //
+}
```
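
This exception replaces the generic `Rubix\ML\Exceptions\InvalidArgumentException` previously thrown by the ThresholdedReLU constructor (next file). A minimal sketch of catching it; the negative threshold is illustrative:

```php
<?php

declare(strict_types=1);

use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;

try {
    // Negative thresholds are rejected by the constructor.
    $activationFn = new ThresholdedReLU(threshold: -0.5);
} catch (InvalidThresholdException $e) {
    echo $e->getMessage(); // "Threshold must be positive, -0.5 given."
}
```

Because it extends the SPL `InvalidArgumentException`, callers can catch either the specific or the generic type.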

src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php

30 additions, 19 deletions

```diff
@@ -1,9 +1,14 @@
 <?php
 
+declare(strict_types=1);
+
 namespace Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU;
 
-use Tensor\Matrix;
-use Rubix\ML\Exceptions\InvalidArgumentException;
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
+use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;
 
 /**
  * Thresholded ReLU
@@ -18,8 +23,9 @@
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
+ * @author Samuel Akopyan <leumas.a@gmail.com>
  */
-class ThresholdedReLU implements ActivationFunction
+class ThresholdedReLU implements ActivationFunction, IBufferDerivative
 {
     /**
      * The input value necessary to trigger an activation.
@@ -29,14 +35,17 @@ class ThresholdedReLU implements ActivationFunction
     protected float $threshold;
 
     /**
-     * @param float $threshold
-     * @throws InvalidArgumentException
+     * Class constructor.
+     *
+     * @param float $threshold The input value necessary to trigger an activation.
+     * @throws InvalidThresholdException
      */
     public function __construct(float $threshold = 1.0)
     {
         if ($threshold < 0.0) {
-            throw new InvalidArgumentException('Threshold must be'
-                . " positive, $threshold given.");
+            throw new InvalidThresholdException(
+                message: "Threshold must be positive, $threshold given."
+            );
         }
 
         $this->threshold = $threshold;
@@ -45,35 +54,37 @@ public function __construct(float $threshold = 1.0)
     /**
      * Compute the activation.
      *
-     * @internal
+     * f(x) = x if x > threshold, 0 otherwise
      *
-     * @param Matrix $input
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function activate(Matrix $input) : Matrix
+    public function activate(NDArray $input) : NDArray
     {
-        return NumPower::greater($input, $this->threshold) * $input;
+        // Create a mask where input > threshold
+        $mask = NumPower::greater($input, $this->threshold);
+
+        // Apply the mask to the input
+        return NumPower::multiply($input, $mask);
     }
 
     /**
      * Calculate the derivative of the activation.
      *
-     * @internal
+     * f'(x) = 1 if x > threshold, 0 otherwise
      *
-     * @param Matrix $input
-     * @param Matrix $output
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function differentiate(Matrix $input, Matrix $output) : Matrix
+    public function differentiate(NDArray $input) : NDArray
     {
+        // The derivative is 1 where input > threshold, 0 otherwise
         return NumPower::greater($input, $this->threshold);
     }
 
     /**
      * Return the string representation of the object.
      *
-     * @internal
-     *
      * @return string
      */
     public function __toString() : string
```
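
And a usage sketch of the reworked ThresholdedReLU (again, `NumPower::array()` and the sample batch are illustrative assumptions, not part of this diff):

```php
<?php

declare(strict_types=1);

use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;

// Hypothetical input batch; NumPower::array() is assumed to construct an NDArray.
$input = NumPower::array([-1.0, 0.5, 1.0, 2.5]);

$activationFn = new ThresholdedReLU(threshold: 1.0);

// NumPower::greater() is a strict comparison, so only values strictly
// above the threshold pass through: [0.0, 0.0, 0.0, 2.5]
$activations = $activationFn->activate($input);

// The derivative is the same 0/1 mask: [0.0, 0.0, 0.0, 1.0]
$gradient = $activationFn->differentiate($input);
```

Note that a value exactly equal to the threshold is zeroed, matching the docblock's strict `x > threshold` condition.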
