@@ -2,17 +2,20 @@ import 'package:ml_algo/src/cost_function/cost_function.dart';
22import 'package:ml_algo/src/helpers/normalize_class_labels.dart' ;
33import 'package:ml_algo/src/helpers/validate_class_labels.dart' ;
44import 'package:ml_algo/src/link_function/link_function.dart' ;
5- import 'package:ml_linalg/linalg.dart' ;
5+ import 'package:ml_linalg/dtype.dart' ;
6+ import 'package:ml_linalg/matrix.dart' ;
7+ import 'package:ml_linalg/vector.dart' ;
68
79class LogLikelihoodCostFunction implements CostFunction {
8- LogLikelihoodCostFunction (
9- this ._linkFunction , this ._positiveLabel, this ._negativeLabel ) {
10+ LogLikelihoodCostFunction (this ._linkFunction, this ._positiveLabel,
11+ this ._negativeLabel , this ._dtype ) {
1012 validateClassLabels (_positiveLabel, _negativeLabel);
1113 }
1214
1315 final LinkFunction _linkFunction;
1416 final num _positiveLabel;
1517 final num _negativeLabel;
18+ final DType _dtype;
1619
1720 @override
1821 double getCost (Matrix x, Matrix w, Matrix y) {
@@ -39,9 +42,9 @@ class LogLikelihoodCostFunction implements CostFunction {
3942 @override
4043 Matrix getHessian (Matrix x, Matrix w, Matrix y) {
4144 final prediction = _linkFunction.link (x * w).toVector ();
42- final ones = Vector .filled (x.rowsNum, 1.0 , dtype: x.dtype );
45+ final ones = Vector .filled (x.rowsNum, 1.0 , dtype: _dtype );
4346 final V = Matrix .diagonal ((prediction * (ones - prediction)).toList (),
44- dtype: x.dtype );
47+ dtype: _dtype );
4548
4649 return x.transpose () * V * x;
4750 }
0 commit comments