Skip to content

Commit ac14aa5

Browse files
authored
Merge pull request #68 from bruker-biosensors/feature/65_Support_analytical_hessian
Revise cost function validation
2 parents f395819 + 85ad7f2 commit ac14aa5

14 files changed

+93
-146
lines changed

minuit2.UnitTests/Any_gradient_based_minimizer.spec.cs

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,7 @@ public abstract class Any_gradient_based_minimizer(IMinimizer minimizer) : Any_m
1414
private readonly ConfigurableLeastSquaresProblem _defaultProblem = new CubicPolynomialLeastSquaresProblem();
1515

1616
[Test]
17-
public void when_asked_to_minimize_a_cost_function_with_an_analytical_gradient_throwing_an_exception_for_the_given_parameter_configurations_throws_a_cost_function_error()
18-
{
19-
var cost = new ModelEvaluatingCostFunction(1, ["offset", "slope"], (x, p) => p[0] + p[1] * x,
20-
modelGradient: (_, _) => throw new TestException());
21-
var parameterConfigurations = new[] { Variable("offset", 1), Variable("slope", 1) };
22-
23-
Action action = () => _minimizer.Minimize(cost, parameterConfigurations);
24-
25-
action.Should().Throw<CostFunctionError>().WithMessage("*gradient*").WithInnerException<TestException>();
26-
}
27-
28-
[Test]
29-
public void when_asked_to_minimize_a_cost_function_with_an_analytical_hessian_throwing_an_exception_for_the_given_parameter_configurations_throws_a_cost_function_error()
30-
{
31-
var cost = new ModelEvaluatingCostFunction(1, ["offset", "slope"], (x, p) => p[0] + p[1] * x,
32-
modelHessian: (_, _) => throw new TestException());
33-
var parameterConfigurations = new[] { Variable("offset", 1), Variable("slope", 1) };
34-
35-
Action action = () => _minimizer.Minimize(cost, parameterConfigurations);
36-
37-
action.Should().Throw<CostFunctionError>().WithMessage("*Hessian*").WithInnerException<TestException>();
38-
}
39-
40-
[Test]
41-
public void when_asked_to_minimize_a_cost_function_with_an_analytical_hessian_diagonal_throwing_an_exception_for_the_given_parameter_configurations_throws_a_cost_function_error()
42-
{
43-
var cost = new ModelEvaluatingCostFunction(1, ["offset", "slope"], (x, p) => p[0] + p[1] * x,
44-
modelHessianDiagonal: (_, _) => throw new TestException());
45-
var parameterConfigurations = new[] { Variable("offset", 1), Variable("slope", 1) };
46-
47-
Action action = () => _minimizer.Minimize(cost, parameterConfigurations);
48-
49-
action.Should().Throw<CostFunctionError>().WithMessage("*Hessian diagonal*").WithInnerException<TestException>();
50-
}
51-
52-
[Test]
53-
public void when_asked_to_minimize_a_cost_function_with_an_analytical_gradient_that_returns_the_wrong_size_throws_a_cost_function_error(
17+
public void when_asked_to_minimize_a_cost_function_with_an_analytical_gradient_that_returns_the_wrong_size_throws_an_exception(
5418
[Values(1, 3)] int flawedGradientSize)
5519
{
5620
var cost = new ModelEvaluatingCostFunction(1, ["offset", "slope"], (x, p) => p[0] + p[1] * x,
@@ -59,7 +23,7 @@ public void when_asked_to_minimize_a_cost_function_with_an_analytical_gradient_t
5923

6024
Action action = () => _minimizer.Minimize(cost, parameterConfigurations);
6125

62-
action.Should().Throw<CostFunctionError>().WithMessage("*gradient*");
26+
action.Should().Throw<InvalidCostFunctionException>().WithMessage("*gradient*");
6327
}
6428

6529
[Test]
@@ -72,7 +36,7 @@ public void when_asked_to_minimize_a_cost_function_with_an_analytical_hessian_of
7236

7337
Action action = () => _minimizer.Minimize(cost, parameterConfigurations);
7438

75-
action.Should().Throw<CostFunctionError>().WithMessage("*Hessian*");
39+
action.Should().Throw<InvalidCostFunctionException>().WithMessage("*Hessian*");
7640
}
7741

7842
[Test]
@@ -85,7 +49,7 @@ public void when_asked_to_minimize_a_cost_function_with_an_analytical_hessian_di
8549

8650
Action action = () => _minimizer.Minimize(cost, parameterConfigurations);
8751

88-
action.Should().Throw<CostFunctionError>().WithMessage("*Hessian diagonal*");
52+
action.Should().Throw<InvalidCostFunctionException>().WithMessage("*Hessian diagonal*");
8953
}
9054

9155
[TestCaseSource(nameof(WellPosedMinimizationProblems))]

minuit2.UnitTests/Least_squares_cost_function_with_batch_evaluation.spec.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class A_least_squares_cost_function_with_batch_evaluation
1212
private static int AnyCount(int min = 10, int max = 100) => Any.Integer().Between(min, max);
1313
private static List<double> AnyValues(int count) => Enumerable.Range(0, count).Select(_ => (double)Any.Double()).ToList();
1414

15-
private static IReadOnlyList<double> TestModel(IReadOnlyList<double> x, IReadOnlyList<double> p) =>
15+
private static double[] TestModel(IReadOnlyList<double> x, IReadOnlyList<double> p) =>
1616
x.Select(xx => p[0] * xx + p[1] * p[1] * xx).ToArray();
1717

1818
public class With_a_uniform_y_error

minuit2.UnitTests/MinimizationProblems/ConfigurableLeastSquaresProblem.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public class LeastSquaresCostBuilder(
4747
true when _hasGradient => CostFunction.LeastSquares(xValues, yValues, yError, _parameterNames, model, modelGradient, _errorDefinitionInSigma),
4848
false when _hasGradient => CostFunction.LeastSquares(xValues, yValues, _parameterNames, model, modelGradient, _errorDefinitionInSigma),
4949
true => CostFunction.LeastSquares(xValues, yValues, yError, _parameterNames, model, _errorDefinitionInSigma),
50-
false => CostFunction.LeastSquares(xValues, yValues, _parameterNames, model, _errorDefinitionInSigma),
50+
false => CostFunction.LeastSquares(xValues, yValues, _parameterNames, model, _errorDefinitionInSigma)
5151
};
5252

5353
public LeastSquaresCostBuilder WithUnknownYErrors()

minuit2.net/CostFunctionDerivativesGuard.cs

Lines changed: 0 additions & 74 deletions
This file was deleted.

minuit2.net/CostFunctionError.cs

Lines changed: 0 additions & 3 deletions
This file was deleted.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using minuit2.net.CostFunctions;
2+
3+
namespace minuit2.net;
4+
5+
internal static class CostFunctionValidation
6+
{
7+
public static void EnsureValidDerivativeSizes(
8+
ICostFunction costFunction,
9+
IReadOnlyList<double> parameterValues)
10+
{
11+
var exceptions = new List<Exception>();
12+
13+
if (costFunction.HasGradient)
14+
EnsureValidGradientSize(costFunction, parameterValues, exceptions);
15+
if (costFunction.HasHessian)
16+
EnsureValidHessianSize(costFunction, parameterValues, exceptions);
17+
if (costFunction.HasHessianDiagonal)
18+
EnsureValidHessianDiagonalSize(costFunction, parameterValues, exceptions);
19+
20+
if (exceptions.Count == 1)
21+
throw exceptions.Single();
22+
if (exceptions.Count > 1)
23+
throw new AggregateException(exceptions);
24+
}
25+
26+
private static void EnsureValidGradientSize(
27+
ICostFunction costFunction,
28+
IReadOnlyList<double> parameterValues,
29+
List<Exception> exceptions)
30+
{
31+
var size = costFunction.GradientFor(parameterValues).Count;
32+
var expectedSize = costFunction.Parameters.Count;
33+
if (size != expectedSize)
34+
exceptions.Add(new InvalidCostFunctionException(
35+
$"Invalid gradient size: expected {expectedSize} value(s) (one per parameter), but got {size}."));
36+
}
37+
38+
private static void EnsureValidHessianSize(
39+
ICostFunction costFunction,
40+
IReadOnlyList<double> parameterValues,
41+
List<Exception> exceptions)
42+
{
43+
var size = costFunction.HessianFor(parameterValues).Count;
44+
var expectedSize = costFunction.Parameters.Count * costFunction.Parameters.Count;
45+
if (size != expectedSize)
46+
exceptions.Add(new InvalidCostFunctionException(
47+
$"Invalid Hessian size: expected {expectedSize} value(s) (one per parameter pair), but got {size}."));
48+
}
49+
50+
private static void EnsureValidHessianDiagonalSize(
51+
ICostFunction costFunction,
52+
IReadOnlyList<double> parameterValues,
53+
List<Exception> exceptions)
54+
{
55+
var size = costFunction.HessianDiagonalFor(parameterValues).Count;
56+
var expectedSize = costFunction.Parameters.Count;
57+
if (size != expectedSize)
58+
exceptions.Add(new InvalidCostFunctionException(
59+
$"Invalid Hessian diagonal size: expected {expectedSize} value(s) (one per parameter), but got {size}."));
60+
}
61+
}

minuit2.net/CostFunctions/CostFunction.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using static minuit2.net.CostFunctions.DataPointGuard;
2-
31
namespace minuit2.net.CostFunctions;
42

53
public static class CostFunction
@@ -92,7 +90,7 @@ private static LeastSquares LeastSquaresWithUnknownYError(
9290
Func<double, IReadOnlyList<double>, IReadOnlyList<double>>? modelHessianDiagonal,
9391
double errorDefinitionInSigma)
9492
{
95-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)));
93+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)));
9694
return new LeastSquares(
9795
x,
9896
y,
@@ -203,7 +201,7 @@ private static LeastSquares LeastSquaresWithUniformYError(
203201
Func<double, IReadOnlyList<double>, IReadOnlyList<double>>? modelHessianDiagonal,
204202
double errorDefinitionInSigma)
205203
{
206-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)));
204+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)));
207205
return new LeastSquares(
208206
x,
209207
y,
@@ -314,7 +312,7 @@ private static LeastSquares LeastSquaresWithIndividualYErrors(
314312
Func<double, IReadOnlyList<double>, IReadOnlyList<double>>? modelHessianDiagonal,
315313
double errorDefinitionInSigma)
316314
{
317-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)), (yError, nameof(yError)));
315+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)), (yError, nameof(yError)));
318316
return new LeastSquares(
319317
x,
320318
y,
@@ -335,7 +333,7 @@ public static ICostFunction LeastSquares(
335333
Func<IReadOnlyList<double>, IReadOnlyList<double>, IReadOnlyList<double>> model,
336334
double errorDefinitionInSigma = 1)
337335
{
338-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)));
336+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)));
339337
return new LeastSquaresWithBatchEvaluationModel(
340338
x,
341339
y,
@@ -354,7 +352,7 @@ public static ICostFunction LeastSquares(
354352
Func<IReadOnlyList<double>, IReadOnlyList<double>, IReadOnlyList<double>> model,
355353
double errorDefinitionInSigma = 1)
356354
{
357-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)));
355+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)));
358356
return new LeastSquaresWithBatchEvaluationModel(
359357
x,
360358
y,
@@ -373,7 +371,7 @@ public static ICostFunction LeastSquares(
373371
Func<IReadOnlyList<double>, IReadOnlyList<double>, IReadOnlyList<double>> model,
374372
double errorDefinitionInSigma = 1)
375373
{
376-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)), (yError, nameof(yError)));
374+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)), (yError, nameof(yError)));
377375
return new LeastSquaresWithBatchEvaluationModel(
378376
x,
379377
y,
@@ -392,7 +390,7 @@ public static ICostFunction LeastSquaresWithGaussNewtonApproximation(
392390
Func<double, IReadOnlyList<double>, IReadOnlyList<double>> modelGradient,
393391
double errorDefinitionInSigma = 1)
394392
{
395-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)));
393+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)));
396394
return new LeastSquaresWithGaussNewtonApproximation(
397395
x,
398396
y,
@@ -413,7 +411,7 @@ public static ICostFunction LeastSquaresWithGaussNewtonApproximation(
413411
Func<double, IReadOnlyList<double>, IReadOnlyList<double>> modelGradient,
414412
double errorDefinitionInSigma = 1)
415413
{
416-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)));
414+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)));
417415
return new LeastSquaresWithGaussNewtonApproximation(
418416
x,
419417
y,
@@ -434,7 +432,7 @@ public static ICostFunction LeastSquaresWithGaussNewtonApproximation(
434432
Func<double, IReadOnlyList<double>, IReadOnlyList<double>> modelGradient,
435433
double errorDefinitionInSigma = 1)
436434
{
437-
ThrowIfCountMismatchBetween((x, nameof(x)), (y, nameof(y)), (yError, nameof(yError)));
435+
DataValidation.EnsureMatchingSizesBetween((x, nameof(x)), (y, nameof(y)), (yError, nameof(yError)));
438436
return new LeastSquaresWithGaussNewtonApproximation(
439437
x,
440438
y,

minuit2.net/CppException.cs

Lines changed: 0 additions & 3 deletions
This file was deleted.

minuit2.net/CostFunctions/DataPointGuard.cs renamed to minuit2.net/DataValidation.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
namespace minuit2.net.CostFunctions;
1+
namespace minuit2.net;
22
using NamedValues = (IReadOnlyList<double> Values, string Name);
33

4-
internal static class DataPointGuard
4+
internal static class DataValidation
55
{
6-
public static void ThrowIfCountMismatchBetween(NamedValues reference, params NamedValues[] others)
6+
public static void EnsureMatchingSizesBetween(NamedValues reference, params NamedValues[] others)
77
{
88
var exceptions = new List<Exception>();
99
foreach (var other in others)

minuit2.net/HesseErrorCalculator.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using minuit2.net.CostFunctions;
2-
using static minuit2.net.ParameterMappingGuard;
32

43
namespace minuit2.net;
54

@@ -11,7 +10,7 @@ public static IMinimizationResult Refine(
1110
Strategy strategy = Strategy.Balanced,
1211
CancellationToken cancellationToken = default)
1312
{
14-
ThrowIfNoUniqueMappingBetween(
13+
ParameterValidation.EnsureUniqueMappingBetween(
1514
costFunction.Parameters,
1615
result.Parameters,
1716
"minimization result",
@@ -32,6 +31,6 @@ public static IMinimizationResult Refine(
3231

3332
return success
3433
? new MinimizationResult(minimum, costFunction)
35-
: throw new CppException();
34+
: throw new NativeMinuit2Exception();
3635
}
3736
}

0 commit comments

Comments
 (0)