Skip to content

Commit 08ec84e

Browse files
committed
fix: make MCDropoutLayer thread-safe
1 parent 89f0aae commit 08ec84e

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

src/UncertaintyQuantification/Layers/MCDropoutLayer.cs

Lines changed: 26 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
using AiDotNet.LinearAlgebra;
22
using AiDotNet.NeuralNetworks.Layers;
3+
using System.Threading;
34

45
namespace AiDotNet.UncertaintyQuantification.Layers;
56

@@ -27,12 +28,12 @@ namespace AiDotNet.UncertaintyQuantification.Layers;
2728
/// </remarks>
2829
public class MCDropoutLayer<T> : LayerBase<T>
2930
{
30-
private Random _rng;
3131
private readonly double _dropoutRate;
3232
private readonly T _scale;
3333
private readonly int? _initialSeed;
34-
private Tensor<T>? _lastInput;
35-
private Vector<T>? _dropoutMask;
34+
private readonly ThreadLocal<Random> _rng;
35+
private readonly ThreadLocal<Tensor<T>?> _lastInput = new(() => null);
36+
private readonly ThreadLocal<Vector<T>?> _dropoutMask = new(() => null);
3637
private bool _mcMode; // Monte Carlo mode - always apply dropout
3738

3839
/// <summary>
@@ -69,12 +70,20 @@ public MCDropoutLayer(double dropoutRate = 0.5, bool mcMode = false, int? random
6970
_dropoutRate = dropoutRate;
7071
_mcMode = mcMode;
7172
_initialSeed = randomSeed;
72-
_rng = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random();
73+
_rng = new ThreadLocal<Random>(() =>
74+
{
75+
if (_initialSeed.HasValue)
76+
{
77+
return new Random(unchecked(_initialSeed.Value + Thread.CurrentThread.ManagedThreadId));
78+
}
79+
80+
return new Random();
81+
});
7382
}
7483

7584
internal void ResetRng(int seed)
7685
{
77-
_rng = new Random(seed);
86+
_rng.Value = new Random(seed);
7887
}
7988

8089
/// <summary>
@@ -84,30 +93,32 @@ internal void ResetRng(int seed)
8493
/// <returns>The output tensor with dropout applied if in training or MC mode.</returns>
8594
public override Tensor<T> Forward(Tensor<T> input)
8695
{
87-
_lastInput = input;
96+
_lastInput.Value = input;
8897

8998
// Apply dropout if in training mode OR Monte Carlo mode
9099
if (!IsTrainingMode && !_mcMode)
91100
return input;
92101

93102
var inputVector = input.ToVector();
94-
_dropoutMask = new Vector<T>(inputVector.Length);
103+
var mask = new Vector<T>(inputVector.Length);
95104
var outputVector = new Vector<T>(inputVector.Length);
96105

97106
for (int i = 0; i < inputVector.Length; i++)
98107
{
99-
if (_rng.NextDouble() > _dropoutRate)
108+
if (_rng.Value!.NextDouble() > _dropoutRate)
100109
{
101-
_dropoutMask[i] = _scale;
110+
mask[i] = _scale;
102111
outputVector[i] = NumOps.Multiply(inputVector[i], _scale);
103112
}
104113
else
105114
{
106-
_dropoutMask[i] = NumOps.Zero;
115+
mask[i] = NumOps.Zero;
107116
outputVector[i] = NumOps.Zero;
108117
}
109118
}
110119

120+
_dropoutMask.Value = mask;
121+
111122
var outputTensor = Tensor<T>.FromVector(outputVector);
112123
return input.Shape.Length > 1 ? outputTensor.Reshape(input.Shape) : outputTensor;
113124
}
@@ -119,7 +130,7 @@ public override Tensor<T> Forward(Tensor<T> input)
119130
/// <returns>The gradient to pass to the previous layer.</returns>
120131
public override Tensor<T> Backward(Tensor<T> outputGradient)
121132
{
122-
if (_lastInput == null || _dropoutMask == null)
133+
if (_lastInput.Value == null || _dropoutMask.Value == null)
123134
throw new InvalidOperationException("Forward pass must be called before backward pass.");
124135

125136
if (!IsTrainingMode && !_mcMode)
@@ -130,11 +141,11 @@ public override Tensor<T> Backward(Tensor<T> outputGradient)
130141

131142
for (int i = 0; i < outputGradientVector.Length; i++)
132143
{
133-
inputGradientVector[i] = NumOps.Multiply(outputGradientVector[i], _dropoutMask[i]);
144+
inputGradientVector[i] = NumOps.Multiply(outputGradientVector[i], _dropoutMask.Value[i]);
134145
}
135146

136147
var inputGradientTensor = Tensor<T>.FromVector(inputGradientVector);
137-
return _lastInput.Shape.Length > 1 ? inputGradientTensor.Reshape(_lastInput.Shape) : inputGradientTensor;
148+
return _lastInput.Value.Shape.Length > 1 ? inputGradientTensor.Reshape(_lastInput.Value.Shape) : inputGradientTensor;
138149
}
139150

140151
/// <summary>
@@ -169,8 +180,8 @@ public override void SetParameters(Vector<T> parameters)
169180
/// </summary>
170181
public override void ResetState()
171182
{
172-
_lastInput = null;
173-
_dropoutMask = null;
183+
_lastInput.Value = null;
184+
_dropoutMask.Value = null;
174185
}
175186

176187
public override LayerBase<T> Clone()

0 commit comments

Comments (0)