| 1 | +import 'package:ml_algo/src/cost_function/cost_function.dart'; |
| 2 | +import 'package:ml_algo/src/linear_optimizer/linear_optimizer.dart'; |
| 3 | +import 'package:ml_linalg/matrix.dart'; |
| 4 | +import 'package:xrange/xrange.dart'; |
| 5 | + |
| 6 | +class LeastSquaresNewtonOptimizer implements LinearOptimizer { |
| 7 | + LeastSquaresNewtonOptimizer( |
| 8 | + {required Matrix features, |
| 9 | + required Matrix labels, |
| 10 | + required CostFunction costFunction, |
| 11 | + required int iterationLimit, |
| 12 | + required num minCoefficientsUpdate, |
| 13 | + num lambda = 0}) |
| 14 | + : _features = features, |
| 15 | + _labels = labels, |
| 16 | + _costFunction = costFunction, |
| 17 | + _iterations = integers(0, iterationLimit), |
| 18 | + _minCoefficientsUpdate = minCoefficientsUpdate, |
| 19 | + _lambda = lambda; |
| 20 | + |
| 21 | + final Matrix _features; |
| 22 | + final Matrix _labels; |
| 23 | + final CostFunction _costFunction; |
| 24 | + final Iterable<int> _iterations; |
| 25 | + final List<num> _costPerIteration = []; |
| 26 | + final num _lambda; |
| 27 | + final num _minCoefficientsUpdate; |
| 28 | + |
| 29 | + @override |
| 30 | + List<num> get costPerIteration => _costPerIteration; |
| 31 | + |
| 32 | + @override |
| 33 | + Matrix findExtrema( |
| 34 | + {Matrix? initialCoefficients, |
| 35 | + bool isMinimizingObjective = true, |
| 36 | + bool collectLearningData = false}) { |
| 37 | + var dtype = _features.dtype; |
| 38 | + var coefficients = initialCoefficients ?? |
| 39 | + Matrix.column(List.filled(_features.first.length, 0), dtype: dtype); |
| 40 | + var prevCoefficients = coefficients; |
| 41 | + var coefficientsUpdate = double.maxFinite; |
| 42 | + |
| 43 | + final regularizingTerm = |
| 44 | + Matrix.scalar(_lambda.toDouble(), _features.columnsNum, dtype: dtype); |
| 45 | +    // The Hessian of the least squares cost function does not depend on |
| 46 | +    // the coefficient vector, so it stays constant throughout the whole |
| 47 | +    // optimization procedure. Compute it, together with its (regularized) |
| 48 | +    // inverse, only once, before the iterations start: |
| 49 | + final hessian = _costFunction.getHessian(_features, coefficients, _labels); |
| 50 | + final regularizedInverseHessian = _lambda == 0 |
| 51 | + ? hessian.inverse() |
| 52 | + : (hessian + regularizingTerm).inverse(); |
| 53 | + |
| 54 | + for (final _ in _iterations) { |
| 55 | + if (coefficientsUpdate.isNaN || |
| 56 | + coefficientsUpdate <= _minCoefficientsUpdate) { |
| 57 | + break; |
| 58 | + } |
| 59 | + |
| 60 | + final gradient = |
| 61 | + _costFunction.getGradient(_features, coefficients, _labels); |
| 62 | + |
| 63 | + coefficients = coefficients - regularizedInverseHessian * gradient; |
| 64 | + coefficientsUpdate = (coefficients - prevCoefficients).norm(); |
| 65 | + prevCoefficients = coefficients; |
| 66 | + |
| 67 | + if (collectLearningData) { |
| 68 | + final cost = _costFunction.getCost(_features, coefficients, _labels); |
| 69 | + |
| 70 | + _costPerIteration.add(cost); |
| 71 | + } |
| 72 | + } |
| 73 | + |
| 74 | + return coefficients; |
| 75 | + } |
| 76 | +} |
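For reference, the update implemented above is Newton's method applied to the least squares objective, with an optional λI term added to the Hessian to regularize the inversion. A short sketch of the underlying math, assuming the conventional ½‖Xw − y‖² cost (the exact scaling depends on the CostFunction implementation that gets injected):

$$
J(w) = \tfrac{1}{2}\lVert Xw - y \rVert^2,\qquad
\nabla J(w) = X^\top (Xw - y),\qquad
\nabla^2 J(w) = X^\top X
$$

$$
w_{k+1} = w_k - \left(X^\top X + \lambda I\right)^{-1} \nabla J(w_k)
$$

Since the Hessian X^T X contains no coefficients, its regularized inverse can be computed once before the loop, and each iteration only has to recompute the gradient.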
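Below is a minimal, self-contained sketch (not part of the commit) that reproduces the same step with plain ml_linalg matrices. The toy feature and label values are made up, and the ½‖Xw − y‖² cost is an assumption for illustration; the optimizer above instead delegates gradient and Hessian computation to the injected CostFunction.

```dart
import 'package:ml_linalg/matrix.dart';

void main() {
  // Toy data: 3 samples, 2 features (values chosen purely for illustration).
  final x = Matrix.fromList([
    [1.0, 2.0],
    [1.0, 3.0],
    [1.0, 5.0],
  ]);
  final y = Matrix.column([2.0, 3.0, 5.0]);
  const lambda = 0.1;

  // The Hessian of 1/2 * ||Xw - y||^2 is X^T X; it contains no coefficients,
  // so its regularized inverse is computed once, exactly as in the optimizer.
  final hessian = x.transpose() * x;
  final regularizedInverseHessian =
      (hessian + Matrix.scalar(lambda, x.columnsNum)).inverse();

  var coefficients = Matrix.column([0.0, 0.0]);

  // A few Newton steps; only the gradient X^T (Xw - y) changes per iteration.
  for (var i = 0; i < 3; i++) {
    final gradient = x.transpose() * (x * coefficients - y);
    coefficients = coefficients - regularizedInverseHessian * gradient;
  }

  print(coefficients);
}
```

In this quadratic setting, a single step with lambda = 0 lands exactly on the least squares solution, while a positive lambda only damps the step, so several iterations are needed before the coefficient update drops below the stopping threshold.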