Skip to content

Commit 7077765

Browse files
ooples authored and claude committed
refactor: replace DeepSurvActivation enum with IActivationFunction<T>
Same refactor as DeepHit: - DeepSurvOptions is now generic DeepSurvOptions<T> - Activation property is IActivationFunction<T> (default SELU) - Delete DeepSurvActivation enum entirely - Delete ApplyActivation/ApplyActivationDerivative scalar switch methods - Forward pass now uses Engine.TensorMatMul + activation.Forward (SIMD) - Serialization uses AssemblyQualifiedName + Activator.CreateInstance Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9293ed1 commit 7077765

File tree

2 files changed

+25
-85
lines changed

2 files changed

+25
-85
lines changed

src/Models/Options/DeepSurvOptions.cs

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
using AiDotNet.ActivationFunctions;
2+
using AiDotNet.Interfaces;
3+
14
namespace AiDotNet.Models.Options;
25

36
/// <summary>
47
/// Configuration options for DeepSurv survival analysis model.
58
/// </summary>
9+
/// <typeparam name="T">The numeric type used for calculations.</typeparam>
610
/// <remarks>
711
/// <para>
812
/// DeepSurv extends the Cox Proportional Hazards model using a deep neural network
@@ -16,17 +20,12 @@ namespace AiDotNet.Models.Options;
1620
/// - How long until a customer cancels their subscription?
1721
/// - How long until a patient experiences disease recurrence?
1822
///
19-
/// What makes survival analysis special is that some observations are "censored" -
20-
/// meaning the event hasn't happened yet when the study ends. For example, if you're
21-
/// studying customer churn and a customer is still subscribed when the study ends,
22-
/// you know they survived at least that long, but you don't know when (or if) they'll churn.
23-
///
2423
/// DeepSurv uses a neural network to learn complex patterns in your data while properly
2524
/// handling this censoring. It outputs a "risk score" - higher values mean higher risk
2625
/// of the event happening sooner.
2726
/// </para>
2827
/// </remarks>
29-
public class DeepSurvOptions
28+
public class DeepSurvOptions<T>
3029
{
3130
/// <summary>
3231
/// Gets or sets the number of hidden layers.
@@ -71,10 +70,10 @@ public class DeepSurvOptions
7170
public double L2Regularization { get; set; } = 0.001;
7271

7372
/// <summary>
74-
/// Gets or sets the activation function type.
73+
/// Gets or sets the activation function for hidden layers.
7574
/// </summary>
7675
/// <value>Default is SELU.</value>
77-
public DeepSurvActivation Activation { get; set; } = DeepSurvActivation.SELU;
76+
public IActivationFunction<T> Activation { get; set; } = new SELUActivation<T>();
7877

7978
/// <summary>
8079
/// Gets or sets whether to use batch normalization.
@@ -93,34 +92,3 @@ public class DeepSurvOptions
9392
/// <value>Default is 10. Set to null to disable early stopping.</value>
9493
public int? EarlyStoppingPatience { get; set; } = 10;
9594
}
96-
97-
/// <summary>
98-
/// Activation functions for DeepSurv.
99-
/// </summary>
100-
public enum DeepSurvActivation
101-
{
102-
/// <summary>
103-
/// Rectified Linear Unit: max(0, x).
104-
/// </summary>
105-
ReLU,
106-
107-
/// <summary>
108-
/// Scaled Exponential Linear Unit - self-normalizing.
109-
/// </summary>
110-
SELU,
111-
112-
/// <summary>
113-
/// Exponential Linear Unit.
114-
/// </summary>
115-
ELU,
116-
117-
/// <summary>
118-
/// Hyperbolic tangent.
119-
/// </summary>
120-
Tanh,
121-
122-
/// <summary>
123-
/// Leaky ReLU with small negative slope.
124-
/// </summary>
125-
LeakyReLU
126-
}

src/Regression/DeepSurv.cs

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public class DeepSurv<T> : AsyncDecisionTreeRegressionBase<T>
9999
/// <summary>
100100
/// Configuration options.
101101
/// </summary>
102-
private readonly DeepSurvOptions _options;
102+
private readonly DeepSurvOptions<T> _options;
103103

104104
/// <summary>
105105
/// Random number generator.
@@ -121,11 +121,11 @@ public class DeepSurv<T> : AsyncDecisionTreeRegressionBase<T>
121121
/// </summary>
122122
/// <param name="options">Configuration options.</param>
123123
/// <param name="regularization">Optional regularization.</param>
124-
public DeepSurv(DeepSurvOptions? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)
124+
public DeepSurv(DeepSurvOptions<T>? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)
125125
: base(null, regularization)
126126
{
127127
_olsIntercept = NumOps.Zero;
128-
_options = options ?? new DeepSurvOptions();
128+
_options = options ?? new DeepSurvOptions<T>();
129129
_weights = [];
130130
_biases = [];
131131
_numFeatures = 0;
@@ -383,19 +383,18 @@ private void InitializeNetwork()
383383
var b = _biases[layer];
384384
int outputSize = w.Columns;
385385

386+
var weightTensor = Tensor<T>.FromMatrix(w);
387+
var biasTensor = Tensor<T>.FromVector(b).Reshape(1, outputSize);
386388
var next = new Vector<T>[n];
387389
for (int i = 0; i < n; i++)
388390
{
389-
next[i] = new Vector<T>(outputSize);
390-
for (int j = 0; j < outputSize; j++)
391-
{
392-
// Engine-accelerated dot product for input · weights[:,j]
393-
var wCol = new Vector<T>(current[i].Length);
394-
for (int k = 0; k < current[i].Length; k++) wCol[k] = w[k, j];
395-
T sum = NumOps.Add(b[j], Engine.DotProduct(current[i], wCol));
396-
397-
next[i][j] = NumOps.FromDouble(ApplyActivation(NumOps.ToDouble(sum)));
398-
}
391+
// SIMD: output = input @ weights + biases via Engine.TensorMatMul
392+
var inputTensor = Tensor<T>.FromVector(current[i]).Reshape(1, current[i].Length);
393+
var result = Engine.TensorBroadcastAdd(
394+
Engine.TensorMatMul(inputTensor, weightTensor), biasTensor);
395+
// SIMD activation via IActivationFunction.Forward
396+
result = _options.Activation.Forward(result);
397+
next[i] = result.Reshape(outputSize).ToVector();
399398
}
400399

401400
hiddenOutputs.Add(current);
@@ -579,37 +578,6 @@ private T FindTimeForHazard(T targetH0)
579578
}
580579

581580
/// <summary>
582-
/// Applies the activation function.
583-
/// </summary>
584-
private double ApplyActivation(double x)
585-
{
586-
return _options.Activation switch
587-
{
588-
DeepSurvActivation.ReLU => Math.Max(0, x),
589-
DeepSurvActivation.SELU => x >= 0 ? 1.0507 * x : 1.0507 * 1.6733 * (Math.Exp(x) - 1),
590-
DeepSurvActivation.ELU => x >= 0 ? x : Math.Exp(x) - 1,
591-
DeepSurvActivation.Tanh => Math.Tanh(x),
592-
DeepSurvActivation.LeakyReLU => x >= 0 ? x : 0.01 * x,
593-
_ => Math.Max(0, x)
594-
};
595-
}
596-
597-
/// <summary>
598-
/// Applies the activation function derivative.
599-
/// </summary>
600-
private double ApplyActivationDerivative(double activated)
601-
{
602-
return _options.Activation switch
603-
{
604-
DeepSurvActivation.ReLU => activated > 0 ? 1 : 0,
605-
DeepSurvActivation.SELU => activated >= 0 ? 1.0507 : 1.0507 * 1.6733 * Math.Exp(activated / 1.0507),
606-
DeepSurvActivation.ELU => activated >= 0 ? 1 : activated + 1,
607-
DeepSurvActivation.Tanh => 1 - activated * activated,
608-
DeepSurvActivation.LeakyReLU => activated >= 0 ? 1 : 0.01,
609-
_ => activated > 0 ? 1 : 0
610-
};
611-
}
612-
613581
private int[] GetSortedIndices(Vector<T> times)
614582
{
615583
return Enumerable.Range(0, times.Length)
@@ -692,7 +660,7 @@ public override byte[] Serialize()
692660
// Options
693661
writer.Write(_options.NumHiddenLayers);
694662
writer.Write(_options.HiddenLayerSize);
695-
writer.Write((int)_options.Activation);
663+
writer.Write(_options.Activation.GetType().AssemblyQualifiedName ?? _options.Activation.GetType().FullName ?? _options.Activation.GetType().Name);
696664
writer.Write(_numFeatures);
697665

698666
// Weights and biases
@@ -762,7 +730,11 @@ public override void Deserialize(byte[] modelData)
762730

763731
_options.NumHiddenLayers = reader.ReadInt32();
764732
_options.HiddenLayerSize = reader.ReadInt32();
765-
_options.Activation = (DeepSurvActivation)reader.ReadInt32();
733+
string activationTypeName = reader.ReadString();
734+
var activationType = Type.GetType(activationTypeName);
735+
_options.Activation = activationType is not null
736+
? (IActivationFunction<T>)(Activator.CreateInstance(activationType) ?? new SELUActivation<T>())
737+
: new SELUActivation<T>();
766738
_numFeatures = reader.ReadInt32();
767739

768740
int numLayers = reader.ReadInt32();

0 commit comments

Comments (0)