Skip to content

Commit 89f0aae

Browse files
committed
fix: validate BayesianDenseLayer batch dimension
1 parent 14a718c commit 89f0aae

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

src/UncertaintyQuantification/Layers/BayesianDenseLayer.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,19 @@ public override Tensor<T> Forward(Tensor<T> input)
229229
SampleWeights();
230230
}
231231

232-
var batch = input.Rank == 1 ? 1 : input.Shape[0];
233-
if (batch <= 0)
232+
int batch;
233+
if (input.Rank == 1)
234234
{
235235
batch = 1;
236236
}
237+
else
238+
{
239+
batch = input.Shape[0];
240+
if (batch <= 0)
241+
{
242+
throw new ArgumentException("Expected input tensor to have a positive batch dimension (Shape[0]).", nameof(input));
243+
}
244+
}
237245

238246
var expectedLength = batch * _inputSize;
239247
if (input.Length != expectedLength)
@@ -275,11 +283,19 @@ public override Tensor<T> Backward(Tensor<T> outputGradient)
275283
if (_lastInput == null || _sampledWeights == null || _lastPreActivation == null)
276284
throw new InvalidOperationException("Forward pass must be called before backward pass.");
277285

278-
var batch = _lastInput.Rank == 1 ? 1 : _lastInput.Shape[0];
279-
if (batch <= 0)
286+
int batch;
287+
if (_lastInput.Rank == 1)
280288
{
281289
batch = 1;
282290
}
291+
else
292+
{
293+
batch = _lastInput.Shape[0];
294+
if (batch <= 0)
295+
{
296+
throw new ArgumentException("Expected last input tensor to have a positive batch dimension (Shape[0]).", nameof(outputGradient));
297+
}
298+
}
283299

284300
var expectedGradientLength = batch * _outputSize;
285301
if (outputGradient.Length != expectedGradientLength)

0 commit comments

Comments
 (0)