Skip to content

Commit 1cbad5c

Browse files
committed
Refactoring ThresholdedReLU with IBufferDerivative
1 parent e6a6143 commit 1cbad5c

File tree

4 files changed

+290
-20
lines changed

4 files changed

+290
-20
lines changed

docs/neural-network/activation-functions/thresholded-relu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Thresholded ReLU maintains the computational efficiency of standard ReLU while a
2222
## Plots
2323
<img src="../../images/activation-functions/thresholded-relu.png" alt="Thresholded ReLU Function" width="500" height="auto">
2424

25-
<img src="../../images/activation-functions/thresholded-derivative.png" alt="Thresholded ReLU Derivative" width="500" height="auto">
25+
<img src="../../images/activation-functions/thresholded-relu-derivative.png" alt="Thresholded ReLU Derivative" width="500" height="auto">
2626

2727
## Example
2828
```php
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<?php

declare(strict_types=1);

namespace Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions;

use InvalidArgumentException;

/**
 * Invalid Threshold Exception
 *
 * Thrown by the ThresholdedReLU constructor when it is given a negative
 * threshold value. Extends the SPL InvalidArgumentException so existing
 * callers catching the generic exception keep working.
 *
 * @category Machine Learning
 * @package Rubix/ML
 * @author Samuel Akopyan <leumas.a@gmail.com>
 */
class InvalidThresholdException extends InvalidArgumentException
{
    //
}

src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
<?php
22

3+
declare(strict_types=1);
4+
35
namespace Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU;
46

5-
use Tensor\Matrix;
6-
use Rubix\ML\Exceptions\InvalidArgumentException;
7+
use NumPower;
8+
use NDArray;
9+
use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
10+
use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
11+
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;
712

813
/**
914
* Thresholded ReLU
@@ -18,8 +23,9 @@
1823
* @category Machine Learning
1924
* @package Rubix/ML
2025
* @author Andrew DalPino
26+
* @author Samuel Akopyan <leumas.a@gmail.com>
2127
*/
22-
class ThresholdedReLU implements ActivationFunction
28+
class ThresholdedReLU implements ActivationFunction, IBufferDerivative
2329
{
2430
/**
2531
* The input value necessary to trigger an activation.
@@ -29,14 +35,17 @@ class ThresholdedReLU implements ActivationFunction
2935
protected float $threshold;
3036

3137
/**
 * Class constructor.
 *
 * @param float $threshold The input value necessary to trigger an activation.
 * @throws InvalidThresholdException If the threshold is negative.
 */
public function __construct(float $threshold = 1.0)
{
    // A threshold of exactly 0.0 is accepted (it degenerates to standard
    // ReLU), so only strictly negative values are rejected. The message
    // says "non-negative" to match the guard — the previous wording
    // ("must be positive") contradicted the `< 0.0` check.
    if ($threshold < 0.0) {
        throw new InvalidThresholdException(
            message: "Threshold must be non-negative, $threshold given."
        );
    }

    $this->threshold = $threshold;
}
@@ -45,35 +54,37 @@ public function __construct(float $threshold = 1.0)
4554
/**
 * Compute the activation.
 *
 * f(x) = x if x > threshold, 0 otherwise
 *
 * @param NDArray $input
 * @return NDArray
 */
public function activate(NDArray $input) : NDArray
{
    // Elements at or below the threshold are zeroed out by multiplying the
    // input with the binary (0/1) indicator produced by the comparison.
    return NumPower::multiply($input, NumPower::greater($input, $this->threshold));
}
5770

5871
/**
 * Calculate the derivative of the activation.
 *
 * f'(x) = 1 if x > threshold, 0 otherwise
 *
 * @param NDArray $input
 * @return NDArray
 */
public function differentiate(NDArray $input) : NDArray
{
    // The comparison itself is the gradient: 1 past the threshold, else 0.
    $gradient = NumPower::greater($input, $this->threshold);

    return $gradient;
}
7184

7285
/**
7386
* Return the string representation of the object.
7487
*
75-
* @internal
76-
*
7788
* @return string
7889
*/
7990
public function __toString() : string
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
<?php

declare(strict_types=1);

namespace Rubix\ML\Tests\NeuralNet\ActivationFunctions\ThresholdedReLU;

use Generator;
use NDArray;
use NumPower;
use PHPUnit\Framework\Attributes\CoversClass;
use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\Attributes\Group;
use PHPUnit\Framework\Attributes\Test;
use PHPUnit\Framework\Attributes\TestDox;
use PHPUnit\Framework\TestCase;
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;

#[Group('ActivationFunctions')]
#[CoversClass(ThresholdedReLU::class)]
class ThresholdedReLUTest extends TestCase
{
    /**
     * The activation function fixture under test (threshold 1.0).
     *
     * @var ThresholdedReLU
     */
    protected ThresholdedReLU $activationFn;

    /**
     * The threshold used by the default fixture.
     *
     * @var float
     */
    protected float $threshold = 1.0;

    /**
     * Inputs paired with their expected activations (threshold 1.0).
     *
     * @return Generator<array>
     */
    public static function computeProvider() : Generator
    {
        yield 'row vector' => [
            NumPower::array([
                [2.0, 1.0, 0.5, 0.0, -1.0, 1.5, -0.5],
            ]),
            [
                [2.0, 0.0, 0.0, 0.0, 0.0, 1.5, 0.0],
            ],
        ];

        yield '3x3 matrix' => [
            NumPower::array([
                [1.2, 0.31, 1.49],
                [0.99, 1.08, 0.03],
                [1.05, 0.52, 1.54],
            ]),
            [
                [1.2, 0.0, 1.49],
                [0.0, 1.08, 0.0],
                [1.05, 0.0, 1.54],
            ],
        ];
    }

    /**
     * Inputs paired with their expected derivatives (threshold 1.0).
     *
     * @return Generator<array>
     */
    public static function differentiateProvider() : Generator
    {
        yield 'row vector' => [
            NumPower::array([
                [2.0, 1.0, 0.5, 0.0, -1.0, 1.5, -0.5],
            ]),
            [
                [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            ],
        ];

        yield '3x3 matrix' => [
            NumPower::array([
                [1.2, 0.31, 1.49],
                [0.99, 1.08, 0.03],
                [1.05, 0.52, 1.54],
            ]),
            [
                [1.0, 0.0, 1.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 1.0],
            ],
        ];
    }

    /**
     * Non-default thresholds with expected activations and derivatives.
     *
     * @return Generator<array>
     */
    public static function thresholdValuesProvider() : Generator
    {
        yield 'threshold 0.5' => [
            0.5,
            NumPower::array([
                [2.0, 1.0, 0.5, 0.0, -1.0],
            ]),
            [
                [2.0, 1.0, 0.0, 0.0, 0.0],
            ],
            [
                [1.0, 1.0, 0.0, 0.0, 0.0],
            ],
        ];

        yield 'threshold 2.0' => [
            2.0,
            NumPower::array([
                [2.0, 1.0, 3.0, 0.0, 2.5],
            ]),
            [
                [0.0, 0.0, 3.0, 0.0, 2.5],
            ],
            [
                [0.0, 0.0, 1.0, 0.0, 1.0],
            ],
        ];
    }

    /**
     * Values at and just around the threshold boundary — the comparison is
     * strict, so exactly 1.0 must NOT activate with the default fixture.
     *
     * @return Generator<array>
     */
    public static function zeroRegionProvider() : Generator
    {
        yield 'single zero' => [
            NumPower::array([[0.0]]),
            [[0.0]],
            [[0.0]],
        ];

        yield 'values around threshold' => [
            NumPower::array([[0.5, 0.9, 0.99, 1.0, 1.01]]),
            [[0.0, 0.0, 0.0, 0.0, 1.01]],
            [[0.0, 0.0, 0.0, 0.0, 1.0]],
        ];
    }

    /**
     * Large-magnitude inputs to confirm the function stays well-behaved.
     *
     * @return Generator<array>
     */
    public static function extremeValuesProvider() : Generator
    {
        yield 'large positive values' => [
            NumPower::array([[10.0, 100.0, 1000.0]]),
            [[10.0, 100.0, 1000.0]],
            [[1.0, 1.0, 1.0]],
        ];

        yield 'large negative values' => [
            NumPower::array([[-10.0, -100.0, -1000.0]]),
            [[0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0]],
        ];
    }

    /**
     * Set up the test case.
     */
    protected function setUp() : void
    {
        parent::setUp();

        $this->activationFn = new ThresholdedReLU($this->threshold);
    }

    #[Test]
    #[TestDox('Can be cast to a string')]
    public function testToString() : void
    {
        // assertSame: exact string identity is the contract here.
        self::assertSame('Thresholded ReLU (threshold: 1)', (string) $this->activationFn);
    }

    #[Test]
    #[TestDox('It throws an exception when threshold is negative')]
    public function testInvalidThresholdException() : void
    {
        $this->expectException(InvalidThresholdException::class);

        new ThresholdedReLU(-1.0);
    }

    #[Test]
    #[TestDox('Correctly activates the input')]
    #[DataProvider('computeProvider')]
    public function testActivate(NDArray $input, array $expected) : void
    {
        $activations = $this->activationFn->activate($input)->toArray();

        self::assertEqualsWithDelta($expected, $activations, 1e-7);
    }

    #[Test]
    #[TestDox('Correctly differentiates the input')]
    #[DataProvider('differentiateProvider')]
    public function testDifferentiate(NDArray $input, array $expected) : void
    {
        $derivatives = $this->activationFn->differentiate($input)->toArray();

        self::assertEqualsWithDelta($expected, $derivatives, 1e-7);
    }

    #[Test]
    #[TestDox('Correctly handles different threshold values')]
    #[DataProvider('thresholdValuesProvider')]
    public function testThresholdValues(float $threshold, NDArray $input, array $expectedActivation, array $expectedDerivative) : void
    {
        $activationFn = new ThresholdedReLU($threshold);

        $activations = $activationFn->activate($input)->toArray();
        $derivatives = $activationFn->differentiate($input)->toArray();

        self::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
        self::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
    }

    #[Test]
    #[TestDox('Correctly handles values around zero')]
    #[DataProvider('zeroRegionProvider')]
    public function testZeroRegion(NDArray $input, array $expectedActivation, array $expectedDerivative) : void
    {
        $activations = $this->activationFn->activate($input)->toArray();
        $derivatives = $this->activationFn->differentiate($input)->toArray();

        self::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
        self::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
    }

    #[Test]
    #[TestDox('Correctly handles extreme values')]
    #[DataProvider('extremeValuesProvider')]
    public function testExtremeValues(NDArray $input, array $expectedActivation, array $expectedDerivative) : void
    {
        $activations = $this->activationFn->activate($input)->toArray();
        $derivatives = $this->activationFn->differentiate($input)->toArray();

        self::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
        self::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
    }
}

0 commit comments

Comments
 (0)