Skip to content

Commit 0a3b8dd

Browse files
committed
Port "Two changes in newton.cpp"
Ported from cjlin1/liblinear@07208b8
1 parent ead90cf commit 0a3b8dd

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

src/main/java/de/bwaldvogel/liblinear/Newton.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ void newton(double[] w) {
4747
double gnorm0 = Blas.dnrm2_(n, g, inc);
4848

4949
f = fun_obj.fun(w);
50-
info("init f %5.3e%n", f);
5150
fun_obj.grad(w, g);
5251
double gnorm = Blas.dnrm2_(n, g, inc);
52+
info("init f %5.3e |g| %5.3e%n", f, gnorm);
5353

5454
if (gnorm <= eps * gnorm0)
5555
search = false;
@@ -70,24 +70,24 @@ void newton(double[] w) {
7070
break;
7171
}
7272

73-
info("iter %2d f %5.3e |g| %5.3e CG %3d step_size %4.2e%n", iter, f, gnorm, cg_iter, step_size);
74-
75-
actred = fold - f;
76-
iter++;
77-
7873
fun_obj.grad(w, g);
79-
8074
gnorm = Blas.dnrm2_(n, g, inc);
75+
76+
info("iter %2d f %5.3e |g| %5.3e CG %3d step_size %4.2e%n", iter, f, gnorm, cg_iter, step_size);
77+
8178
if (gnorm <= eps * gnorm0)
8279
break;
8380
if (f < -1.0e+32) {
8481
info("WARNING: f < -1.0e+32%n");
8582
break;
8683
}
84+
actred = fold - f;
8785
if (Math.abs(actred) <= 1.0e-12 * Math.abs(f)) {
8886
info("WARNING: actred too small%n");
8987
break;
9088
}
89+
90+
iter++;
9191
}
9292
}
9393

@@ -97,7 +97,7 @@ private int pcg(double[] g, double[] M, double[] s, double[] r) {
9797
double one = 1;
9898
double[] d = new double[n];
9999
double[] Hd = new double[n];
100-
double zTr, znewTrnew, alpha, beta, cgtol;
100+
double zTr, znewTrnew, alpha, beta, cgtol, dHd;
101101
double[] z = new double[n];
102102
double Q = 0, newQ, Qdiff;
103103

@@ -116,9 +116,14 @@ private int pcg(double[] g, double[] M, double[] s, double[] r) {
116116

117117
while (cg_iter < max_cg_iter) {
118118
cg_iter++;
119+
119120
fun_obj.Hv(d, Hd);
121+
dHd = Blas.ddot_(n, d, inc, Hd, inc);
122+
// avoid 0/0 in getting alpha
123+
if (dHd <= 1.0e-16)
124+
break;
120125

121-
alpha = zTr / Blas.ddot_(n, d, inc, Hd, inc);
126+
alpha = zTr / dHd;
122127
Blas.daxpy_(n, alpha, d, inc, s, inc);
123128
alpha = -alpha;
124129
Blas.daxpy_(n, alpha, Hd, inc, r, inc);

src/test/java/de/bwaldvogel/liblinear/LinearTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ void testFindBestParametersOnSpliceDataSet_L2R_L2LOSS_SVR() throws Exception {
741741
Parameter param = new Parameter(L2R_L2LOSS_SVR, 1, 0.001, 0.1);
742742
ParameterSearchResult result = Linear.findParameters(problem, param, 5, -1, -1);
743743
assertThat(result.getBestC()).isEqualTo(0.00390625);
744-
assertThat(result.getBestScore()).isEqualTo(0.5699400237604259, Offset.offset(0.0000001));
744+
assertThat(result.getBestScore()).isEqualTo(0.5699399182191544, Offset.offset(0.0000001));
745745
assertThat(result.getBestP()).isEqualTo(0.0);
746746
}
747747

0 commit comments

Comments
 (0)