Skip to content

Commit 130f4bc

Browse files
authored
[US-IF-005]: Implement missing activation functions in ActivationFunctionFactory (#162)
* refactor(US-IF-005): Complete ActivationFunctionFactory mappings for all enum values - Map scalar factory to: ReLU, Sigmoid, Tanh, Linear(Identity), LeakyReLU, ELU, SELU, Softplus, SoftSign, Swish, GELU; Softmax guarded as vector-only - Map vector factory to the same set plus Softmax References: ~/.claude/user-stories/AiDotNet/code_improvements/us-if-005-implement-missing-activation-functions-in-activationfunctionfactory.md * test(US-IF-005): Add ActivationFunctionFactory coverage tests * test(US-IF-005): Add behavior tests for common activation functions (increase coverage) * refactor(US-IF-005): Combine Linear/Identity into single or-pattern mapping per Copilot review
1 parent 6e3376e commit 130f4bc

File tree

3 files changed

+167
-3
lines changed

3 files changed

+167
-3
lines changed

src/Factories/ActivationFunctionFactory.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ public static IActivationFunction<T> CreateActivationFunction(ActivationFunction
4747
return activationFunction switch
4848
{
4949
ActivationFunction.ReLU => new ReLUActivation<T>(),
50-
ActivationFunction.Softmax => throw new NotSupportedException("Softmax is not applicable to single values. Use CreateVectorActivationFunction for Softmax."),
5150
ActivationFunction.Sigmoid => new SigmoidActivation<T>(),
5251
ActivationFunction.Tanh => new TanhActivation<T>(),
53-
ActivationFunction.Identity => new IdentityActivation<T>(),
52+
ActivationFunction.Linear or ActivationFunction.Identity => new IdentityActivation<T>(),
5453
ActivationFunction.LeakyReLU => new LeakyReLUActivation<T>(),
5554
ActivationFunction.ELU => new ELUActivation<T>(),
5655
ActivationFunction.SELU => new SELUActivation<T>(),
56+
ActivationFunction.Softmax => throw new NotSupportedException("Softmax is not applicable to single values. Use CreateVectorActivationFunction for Softmax."),
5757
ActivationFunction.Softplus => new SoftPlusActivation<T>(),
5858
ActivationFunction.SoftSign => new SoftSignActivation<T>(),
5959
ActivationFunction.Swish => new SwishActivation<T>(),
@@ -87,9 +87,10 @@ public static IVectorActivationFunction<T> CreateVectorActivationFunction(Activa
8787
return activationFunction switch
8888
{
8989
ActivationFunction.Softmax => new SoftmaxActivation<T>(),
90+
ActivationFunction.ReLU => new ReLUActivation<T>(),
9091
ActivationFunction.Sigmoid => new SigmoidActivation<T>(),
9192
ActivationFunction.Tanh => new TanhActivation<T>(),
92-
ActivationFunction.Identity => new IdentityActivation<T>(),
93+
ActivationFunction.Linear or ActivationFunction.Identity => new IdentityActivation<T>(),
9394
ActivationFunction.LeakyReLU => new LeakyReLUActivation<T>(),
9495
ActivationFunction.ELU => new ELUActivation<T>(),
9596
ActivationFunction.SELU => new SELUActivation<T>(),
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using System;
2+
using AiDotNet.ActivationFunctions;
3+
using AiDotNet.Enums;
4+
using AiDotNet.Factories;
5+
using AiDotNet.Interfaces;
6+
using Xunit;
7+
8+
namespace AiDotNet.Tests.ActivationFunctions
{
    /// <summary>
    /// Behavior spot-checks for the scalar activation functions produced by
    /// <see cref="AiDotNet.Factories.ActivationFunctionFactory{T}"/>: verifies
    /// well-known input/output pairs and, where mathematically fixed, derivative values.
    /// </summary>
    public class ActivationFunctionBehaviorTests
    {
        // Asserts two doubles are equal within an absolute tolerance;
        // avoids exact == comparison on floating point.
        private static void AssertClose(double actual, double expected, double tol = 1e-6)
        {
            Assert.True(Math.Abs(actual - expected) <= tol, $"Actual {actual} != Expected {expected}");
        }

        [Fact]
        public void ReLU_Activate_And_Derivative()
        {
            var fn = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.ReLU);

            AssertClose(fn.Activate(-1.0), 0.0);   // clamped to zero for x < 0
            AssertClose(fn.Activate(2.5), 2.5);    // identity for x > 0
            AssertClose(fn.Derivative(-1.0), 0.0); // flat on the negative side
        }

        [Fact]
        public void Sigmoid_Activate_And_Derivative()
        {
            var fn = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.Sigmoid);

            var y = fn.Activate(0.0);
            AssertClose(y, 0.5); // sigmoid(0) = 0.5 exactly

            // sigmoid'(0) = 0.25; a loose bound tolerates implementation rounding.
            var dy = fn.Derivative(0.0);
            Assert.True(dy > 0.0 && dy < 0.3);
        }

        [Fact]
        public void Tanh_Activate_And_Derivative()
        {
            var fn = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.Tanh);

            AssertClose(fn.Activate(0.0), 0.0);  // tanh(0) = 0

            // FIX: the test name promised derivative coverage but the derivative
            // was never exercised. tanh'(0) = 1 - tanh^2(0) = 1.
            AssertClose(fn.Derivative(0.0), 1.0);
        }

        [Fact]
        public void Identity_Activate()
        {
            // Linear and Identity map to the same IdentityActivation in the factory.
            var fn = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.Linear);

            AssertClose(fn.Activate(3.14), 3.14); // identity passes the input through unchanged
        }

        [Fact]
        public void LeakyRelu_Activate()
        {
            var fn = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.LeakyReLU);

            // Leaky ReLU is negative (small slope) for x < 0 and positive for x > 0.
            Assert.True(fn.Activate(-2.0) < 0 && fn.Activate(2.0) > 0);
        }

        [Fact]
        public void ELU_SELU_Activate()
        {
            var elu = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.ELU);
            var selu = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.SELU);

            // Both families produce negative (but bounded) output for negative input.
            Assert.True(elu.Activate(-1.0) < 0.0);
            Assert.True(selu.Activate(-1.0) < 0.0);
        }

        [Fact]
        public void Softplus_SoftSign_Swish_GELU()
        {
            var sp = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.Softplus);
            var ss = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.SoftSign);
            var sw = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.Swish);
            var ge = ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.GELU);

            Assert.True(sp.Activate(1.0) > 0.0);                            // softplus is strictly positive
            Assert.True(ss.Activate(1.0) > 0.0 && ss.Activate(1.0) <= 1.0); // softsign bounded in (-1, 1)
            Assert.True(sw.Activate(1.0) > 0.0);                            // swish(1) = 1 * sigmoid(1) > 0
            Assert.True(ge.Activate(1.0) > 0.0);                            // GELU(1) > 0
        }

        [Fact]
        public void Vector_Softmax_Factory()
        {
            // Softmax is vector-only; the vector factory must produce it.
            var vfn = ActivationFunctionFactory<double>.CreateVectorActivationFunction(ActivationFunction.Softmax);

            Assert.IsAssignableFrom<IVectorActivationFunction<double>>(vfn);
        }
    }
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using System;
2+
using System.Linq;
3+
using AiDotNet.Enums;
4+
using AiDotNet.Factories;
5+
using AiDotNet.Interfaces;
6+
using Xunit;
7+
8+
namespace AiDotNet.Tests.Factories
{
    /// <summary>
    /// Coverage tests for <see cref="AiDotNet.Factories.ActivationFunctionFactory{T}"/>:
    /// every supported enum value maps to a concrete activation instance, and
    /// Softmax is rejected by the scalar factory.
    /// </summary>
    public class ActivationFunctionFactoryTests
    {
        // FIX: the original looped over enum values inside a single [Fact],
        // so the first failing value aborted the loop and hid the rest.
        // [Theory] reports each enum value as an independent test case.
        [Theory]
        [InlineData(ActivationFunction.ReLU)]
        [InlineData(ActivationFunction.Sigmoid)]
        [InlineData(ActivationFunction.Tanh)]
        [InlineData(ActivationFunction.Linear)]
        [InlineData(ActivationFunction.LeakyReLU)]
        [InlineData(ActivationFunction.ELU)]
        [InlineData(ActivationFunction.SELU)]
        [InlineData(ActivationFunction.Softplus)]
        [InlineData(ActivationFunction.SoftSign)]
        [InlineData(ActivationFunction.Swish)]
        [InlineData(ActivationFunction.GELU)]
        [InlineData(ActivationFunction.Identity)]
        public void CreateActivationFunction_Returns_For_Scalar_Compatible(ActivationFunction af)
        {
            var fn = ActivationFunctionFactory<double>.CreateActivationFunction(af);

            Assert.NotNull(fn);
            Assert.IsAssignableFrom<IActivationFunction<double>>(fn);
        }

        [Fact]
        public void CreateActivationFunction_Throws_For_Softmax_Scalar()
        {
            // Softmax normalizes across a vector, so the scalar factory must reject it.
            Assert.Throws<NotSupportedException>(() =>
                ActivationFunctionFactory<double>.CreateActivationFunction(ActivationFunction.Softmax));
        }

        // Vector-compatible functions: the scalar set plus Softmax.
        [Theory]
        [InlineData(ActivationFunction.Softmax)]
        [InlineData(ActivationFunction.ReLU)]
        [InlineData(ActivationFunction.Sigmoid)]
        [InlineData(ActivationFunction.Tanh)]
        [InlineData(ActivationFunction.Linear)]
        [InlineData(ActivationFunction.LeakyReLU)]
        [InlineData(ActivationFunction.ELU)]
        [InlineData(ActivationFunction.SELU)]
        [InlineData(ActivationFunction.Softplus)]
        [InlineData(ActivationFunction.SoftSign)]
        [InlineData(ActivationFunction.Swish)]
        [InlineData(ActivationFunction.GELU)]
        [InlineData(ActivationFunction.Identity)]
        public void CreateVectorActivationFunction_Returns_For_Vector_Compatible(ActivationFunction af)
        {
            var fn = ActivationFunctionFactory<double>.CreateVectorActivationFunction(af);

            Assert.NotNull(fn);
            Assert.IsAssignableFrom<IVectorActivationFunction<double>>(fn);
        }
    }
}
77+

0 commit comments

Comments
 (0)