Skip to content

Commit bfe2606

Browse files
authored
LogLikelihood function, Newton optimizer: dtype passed (#242)
1 parent 989c5c5 commit bfe2606

File tree

53 files changed

+171
-63
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+171
-63
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## 16.17.3
4+
- Log Likelihood Cost function:
5+
- `dtype` passed
6+
- Newton optimizer:
7+
- `dtype` passed
8+
- Removed `package:ml_linalg/linalg.dart` and `package:ml_algo/ml_algo.dart` imports
9+
310
## 16.17.2
411
- Code quality:
512
- Strict options turned on

benchmark/decision_tree_classifier.dart

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import 'dart:math';
55
import 'package:benchmark_harness/benchmark_harness.dart';
66
import 'package:ml_algo/ml_algo.dart';
77
import 'package:ml_dataframe/ml_dataframe.dart';
8-
import 'package:ml_linalg/linalg.dart';
8+
import 'package:ml_linalg/matrix.dart';
9+
import 'package:ml_linalg/vector.dart';
910

1011
const observationsNum = 300;
1112
const columnsNum = 11;

benchmark/kd_tree/kd_tree_querying.dart

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ import 'package:benchmark_harness/benchmark_harness.dart';
33
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
44
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
55
import 'package:ml_dataframe/ml_dataframe.dart';
6-
import 'package:ml_linalg/linalg.dart';
6+
import 'package:ml_linalg/matrix.dart';
7+
import 'package:ml_linalg/vector.dart';
78

89
final k = 10;
910

benchmark/random_binary_projection_searcher/searcher_querying.dart

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ import 'package:benchmark_harness/benchmark_harness.dart';
33
import 'package:ml_algo/ml_algo.dart';
44
import 'package:ml_algo/src/retrieval/random_binary_projection_searcher/random_binary_projection_searcher_impl.dart';
55
import 'package:ml_dataframe/ml_dataframe.dart';
6-
import 'package:ml_linalg/linalg.dart';
6+
import 'package:ml_linalg/matrix.dart';
7+
import 'package:ml_linalg/vector.dart';
78

89
final k = 10;
910
final digitCapacity = 10;

lib/src/classifier/_helpers/create_log_likelihood_optimizer.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ LinearOptimizer createLogLikelihoodOptimizer(
5151
linkFunction: linkFunction,
5252
positiveLabel: positiveLabel,
5353
negativeLabel: negativeLabel,
54+
dtype: dtype,
5455
);
5556
final normalizedLabels =
5657
normalizeClassLabels(labels, positiveLabel, negativeLabel);
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import 'package:ml_algo/src/cost_function/cost_function.dart';
22
import 'package:ml_algo/src/cost_function/cost_function_type.dart';
33
import 'package:ml_algo/src/link_function/link_function.dart';
4+
import 'package:ml_linalg/dtype.dart';
45

56
abstract class CostFunctionFactory {
67
CostFunction createByType(
78
CostFunctionType type, {
89
LinkFunction? linkFunction,
910
num? positiveLabel,
1011
num? negativeLabel,
12+
DType dtype,
1113
});
1214
}

lib/src/cost_function/cost_function_factory_impl.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import 'package:ml_algo/src/cost_function/cost_function_type.dart';
44
import 'package:ml_algo/src/cost_function/least_square_cost_function.dart';
55
import 'package:ml_algo/src/cost_function/log_likelihood_cost_function.dart';
66
import 'package:ml_algo/src/link_function/link_function.dart';
7+
import 'package:ml_linalg/dtype.dart';
78

89
class CostFunctionFactoryImpl implements CostFunctionFactory {
910
const CostFunctionFactoryImpl();
@@ -14,6 +15,7 @@ class CostFunctionFactoryImpl implements CostFunctionFactory {
1415
LinkFunction? linkFunction,
1516
num? positiveLabel,
1617
num? negativeLabel,
18+
DType dtype = DType.float32,
1719
}) {
1820
switch (type) {
1921
case CostFunctionType.logLikelihood:
@@ -34,6 +36,7 @@ class CostFunctionFactoryImpl implements CostFunctionFactory {
3436
linkFunction,
3537
positiveLabel,
3638
negativeLabel,
39+
dtype,
3740
);
3841

3942
case CostFunctionType.leastSquare:

lib/src/cost_function/least_square_cost_function.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import 'package:ml_algo/src/cost_function/cost_function.dart';
2-
import 'package:ml_linalg/linalg.dart';
2+
import 'package:ml_linalg/matrix.dart';
33

44
class LeastSquareCostFunction implements CostFunction {
55
const LeastSquareCostFunction();

lib/src/cost_function/log_likelihood_cost_function.dart

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@ import 'package:ml_algo/src/cost_function/cost_function.dart';
22
import 'package:ml_algo/src/helpers/normalize_class_labels.dart';
33
import 'package:ml_algo/src/helpers/validate_class_labels.dart';
44
import 'package:ml_algo/src/link_function/link_function.dart';
5-
import 'package:ml_linalg/linalg.dart';
5+
import 'package:ml_linalg/dtype.dart';
6+
import 'package:ml_linalg/matrix.dart';
7+
import 'package:ml_linalg/vector.dart';
68

79
class LogLikelihoodCostFunction implements CostFunction {
8-
LogLikelihoodCostFunction(
9-
this._linkFunction, this._positiveLabel, this._negativeLabel) {
10+
LogLikelihoodCostFunction(this._linkFunction, this._positiveLabel,
11+
this._negativeLabel, this._dtype) {
1012
validateClassLabels(_positiveLabel, _negativeLabel);
1113
}
1214

1315
final LinkFunction _linkFunction;
1416
final num _positiveLabel;
1517
final num _negativeLabel;
18+
final DType _dtype;
1619

1720
@override
1821
double getCost(Matrix x, Matrix w, Matrix y) {
@@ -39,9 +42,9 @@ class LogLikelihoodCostFunction implements CostFunction {
3942
@override
4043
Matrix getHessian(Matrix x, Matrix w, Matrix y) {
4144
final prediction = _linkFunction.link(x * w).toVector();
42-
final ones = Vector.filled(x.rowsNum, 1.0, dtype: x.dtype);
45+
final ones = Vector.filled(x.rowsNum, 1.0, dtype: _dtype);
4346
final V = Matrix.diagonal((prediction * (ones - prediction)).toList(),
44-
dtype: x.dtype);
47+
dtype: _dtype);
4548

4649
return x.transpose() * V * x;
4750
}

lib/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_generator.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import 'package:ml_linalg/linalg.dart';
1+
import 'package:ml_linalg/vector.dart';
22

33
abstract class InitialCoefficientsGenerator {
44
Vector generate(int length);

0 commit comments

Comments
 (0)