@@ -99,7 +99,7 @@ public class DeepSurv<T> : AsyncDecisionTreeRegressionBase<T>
9999 /// <summary>
100100 /// Configuration options.
101101 /// </summary>
102- private readonly DeepSurvOptions _options ;
102+ private readonly DeepSurvOptions < T > _options ;
103103
104104 /// <summary>
105105 /// Random number generator.
@@ -121,11 +121,11 @@ public class DeepSurv<T> : AsyncDecisionTreeRegressionBase<T>
121121 /// </summary>
122122 /// <param name="options">Configuration options.</param>
123123 /// <param name="regularization">Optional regularization.</param>
124- public DeepSurv ( DeepSurvOptions ? options = null , IRegularization < T , Matrix < T > , Vector < T > > ? regularization = null )
124+ public DeepSurv ( DeepSurvOptions < T > ? options = null , IRegularization < T , Matrix < T > , Vector < T > > ? regularization = null )
125125 : base ( null , regularization )
126126 {
127127 _olsIntercept = NumOps . Zero ;
128- _options = options ?? new DeepSurvOptions ( ) ;
128+ _options = options ?? new DeepSurvOptions < T > ( ) ;
129129 _weights = [ ] ;
130130 _biases = [ ] ;
131131 _numFeatures = 0 ;
@@ -383,19 +383,18 @@ private void InitializeNetwork()
383383 var b = _biases [ layer ] ;
384384 int outputSize = w . Columns ;
385385
386+ var weightTensor = Tensor < T > . FromMatrix ( w ) ;
387+ var biasTensor = Tensor < T > . FromVector ( b ) . Reshape ( 1 , outputSize ) ;
386388 var next = new Vector < T > [ n ] ;
387389 for ( int i = 0 ; i < n ; i ++ )
388390 {
389- next [ i ] = new Vector < T > ( outputSize ) ;
390- for ( int j = 0 ; j < outputSize ; j ++ )
391- {
392- // Engine-accelerated dot product for input · weights[:,j]
393- var wCol = new Vector < T > ( current [ i ] . Length ) ;
394- for ( int k = 0 ; k < current [ i ] . Length ; k ++ ) wCol [ k ] = w [ k , j ] ;
395- T sum = NumOps . Add ( b [ j ] , Engine . DotProduct ( current [ i ] , wCol ) ) ;
396-
397- next [ i ] [ j ] = NumOps . FromDouble ( ApplyActivation ( NumOps . ToDouble ( sum ) ) ) ;
398- }
391+ // SIMD: output = input @ weights + biases via Engine.TensorMatMul
392+ var inputTensor = Tensor < T > . FromVector ( current [ i ] ) . Reshape ( 1 , current [ i ] . Length ) ;
393+ var result = Engine . TensorBroadcastAdd (
394+ Engine . TensorMatMul ( inputTensor , weightTensor ) , biasTensor ) ;
395+ // SIMD activation via IActivationFunction.Forward
396+ result = _options . Activation . Forward ( result ) ;
397+ next [ i ] = result . Reshape ( outputSize ) . ToVector ( ) ;
399398 }
400399
401400 hiddenOutputs . Add ( current ) ;
@@ -579,37 +578,6 @@ private T FindTimeForHazard(T targetH0)
579578 }
580579
581580 /// <summary>Returns the indices of <paramref name="times"/> in sorted order.</summary>
582- /// Applies the activation function.
583- /// </summary>
584- private double ApplyActivation ( double x )
585- {
586- return _options . Activation switch
587- {
588- DeepSurvActivation . ReLU => Math . Max ( 0 , x ) ,
589- DeepSurvActivation . SELU => x >= 0 ? 1.0507 * x : 1.0507 * 1.6733 * ( Math . Exp ( x ) - 1 ) ,
590- DeepSurvActivation . ELU => x >= 0 ? x : Math . Exp ( x ) - 1 ,
591- DeepSurvActivation . Tanh => Math . Tanh ( x ) ,
592- DeepSurvActivation . LeakyReLU => x >= 0 ? x : 0.01 * x ,
593- _ => Math . Max ( 0 , x )
594- } ;
595- }
596-
597- /// <summary>
598- /// Applies the activation function derivative.
599- /// </summary>
600- private double ApplyActivationDerivative ( double activated )
601- {
602- return _options . Activation switch
603- {
604- DeepSurvActivation . ReLU => activated > 0 ? 1 : 0 ,
605- DeepSurvActivation . SELU => activated >= 0 ? 1.0507 : 1.0507 * 1.6733 * Math . Exp ( activated / 1.0507 ) ,
606- DeepSurvActivation . ELU => activated >= 0 ? 1 : activated + 1 ,
607- DeepSurvActivation . Tanh => 1 - activated * activated ,
608- DeepSurvActivation . LeakyReLU => activated >= 0 ? 1 : 0.01 ,
609- _ => activated > 0 ? 1 : 0
610- } ;
611- }
612-
613581 private int [ ] GetSortedIndices ( Vector < T > times )
614582 {
615583 return Enumerable . Range ( 0 , times . Length )
@@ -692,7 +660,7 @@ public override byte[] Serialize()
692660 // Options
693661 writer . Write ( _options . NumHiddenLayers ) ;
694662 writer . Write ( _options . HiddenLayerSize ) ;
695- writer . Write ( ( int ) _options . Activation ) ;
663+ writer . Write ( _options . Activation . GetType ( ) . AssemblyQualifiedName ?? _options . Activation . GetType ( ) . FullName ?? _options . Activation . GetType ( ) . Name ) ;
696664 writer . Write ( _numFeatures ) ;
697665
698666 // Weights and biases
@@ -762,7 +730,11 @@ public override void Deserialize(byte[] modelData)
762730
763731 _options . NumHiddenLayers = reader . ReadInt32 ( ) ;
764732 _options . HiddenLayerSize = reader . ReadInt32 ( ) ;
765- _options . Activation = ( DeepSurvActivation ) reader . ReadInt32 ( ) ;
733+ string activationTypeName = reader . ReadString ( ) ; // type name stored by Serialize in this position
734+ var activationType = Type . GetType ( activationTypeName ) ; // null when the name cannot be resolved
735+ _options . Activation = activationType is { IsAbstract : false } && typeof ( IActivationFunction < T > ) . IsAssignableFrom ( activationType ) // the name comes from the model bytes: only construct concrete activation types
736+ ? ( IActivationFunction < T > ) ( Activator . CreateInstance ( activationType ) ?? new SELUActivation < T > ( ) )
737+ : new SELUActivation < T > ( ) ; // unresolved, abstract, or non-activation type: fall back to SELU
766738 _numFeatures = reader . ReadInt32 ( ) ;
767739
768740 int numLayers = reader . ReadInt32 ( ) ;
0 commit comments