Skip to content

Commit bb15308

Browse files
committed
Make PRNG configurable via Parameter #40
1 parent 85f2747 commit bb15308

File tree

8 files changed

+51
-20
lines changed

8 files changed

+51
-20
lines changed

src/main/java/de/bwaldvogel/liblinear/Linear.java

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ public class Linear {
4343
private static final Object OUTPUT_MUTEX = new Object();
4444
private static PrintStream DEBUG_OUTPUT = System.out;
4545

46-
private static final long DEFAULT_RANDOM_SEED = 0L;
47-
static Random random = new Random(DEFAULT_RANDOM_SEED);
48-
4946
/**
5047
* @param target predicted classes
5148
*/
@@ -63,7 +60,7 @@ public static void crossValidation(Problem prob, Parameter param, int nr_fold, d
6360
for (i = 0; i < l; i++)
6461
perm[i] = i;
6562
for (i = 0; i < l; i++) {
66-
int j = i + random.nextInt(l - i);
63+
int j = i + param.random.nextInt(l - i);
6764
swap(perm, i, j);
6865
}
6966
for (i = 0; i <= nr_fold; i++)
@@ -115,7 +112,7 @@ public static ParameterSearchResult findParameters(Problem prob, Parameter param
115112
for (i = 0; i < l; i++)
116113
perm[i] = i;
117114
for (i = 0; i < l; i++) {
118-
int j = i + random.nextInt(l - i);
115+
int j = i + param.random.nextInt(l - i);
119116
swap(perm, i, j);
120117
}
121118
for (i = 0; i <= nr_fold; i++)
@@ -694,7 +691,7 @@ private static int solve_l2r_l1l2_svc(Problem prob, Parameter param, double[] w,
694691
PGmin_new = Double.POSITIVE_INFINITY;
695692

696693
for (i = 0; i < active_size; i++) {
697-
int j = i + random.nextInt(active_size - i);
694+
int j = i + param.random.nextInt(active_size - i);
698695
swap(index, i, j);
699696
}
700697

@@ -867,7 +864,7 @@ private static int solve_l2r_l1l2_svr(Problem prob, Parameter param, double[] w,
867864
Gnorm1_new = 0;
868865

869866
for (i = 0; i < active_size; i++) {
870-
int j = i + random.nextInt(active_size - i);
867+
int j = i + param.random.nextInt(active_size - i);
871868
swap(index, i, j);
872869
}
873870

@@ -1043,7 +1040,7 @@ private static int solve_l2r_lr_dual(Problem prob, Parameter param, double[] w,
10431040

10441041
while (iter < max_iter) {
10451042
for (i = 0; i < l; i++) {
1046-
int j = i + random.nextInt(l - i);
1043+
int j = i + param.random.nextInt(l - i);
10471044
swap(index, i, j);
10481045
}
10491046
int newton_iter = 0;
@@ -1206,7 +1203,7 @@ private static int solve_l1r_l2_svc(Problem prob_col, Parameter param, double[]
12061203
Gnorm1_new = 0;
12071204

12081205
for (j = 0; j < active_size; j++) {
1209-
int i = j + random.nextInt(active_size - j);
1206+
int i = j + param.random.nextInt(active_size - j);
12101207
swap(index, i, j);
12111208
}
12121209

@@ -1556,7 +1553,7 @@ else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
15561553
QP_Gnorm1_new = 0;
15571554

15581555
for (j = 0; j < QP_active_size; j++) {
1559-
int i = random.nextInt(QP_active_size - j);
1556+
int i = param.random.nextInt(QP_active_size - j);
15601557
swap(index, i, j);
15611558
}
15621559

@@ -2190,7 +2187,7 @@ public static Model train(Problem prob, Parameter param) {
21902187
}
21912188
}
21922189

2193-
SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps);
2190+
SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps, param.random);
21942191
solver.solve(model.w);
21952192
} else {
21962193
if (nr_class == 2) {
@@ -2537,9 +2534,8 @@ public static int getVersion() {
25372534
/**
25382535
* resets the PRNG
25392536
*
2540-
* this is i.a. needed for regression testing (eg. the Weka wrapper)
2537+
* @deprecated Use {@link Parameter#setRandom(Random)} instead
25412538
*/
25422539
public static void resetRandom() {
2543-
random = new Random(DEFAULT_RANDOM_SEED);
25442540
}
25452541
}

src/main/java/de/bwaldvogel/liblinear/Parameter.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
package de.bwaldvogel.liblinear;
22

3+
import java.io.ByteArrayInputStream;
4+
import java.io.ByteArrayOutputStream;
5+
import java.io.IOException;
6+
import java.io.ObjectInputStream;
7+
import java.io.ObjectOutputStream;
38
import java.util.Arrays;
9+
import java.util.Random;
410

511

612
public final class Parameter implements Cloneable {
713

14+
private static final long DEFAULT_RANDOM_SEED = 0L;
15+
816
double C;
917

1018
/** stopping tolerance */
@@ -29,6 +37,8 @@ public final class Parameter implements Cloneable {
2937

3038
boolean regularize_bias = true;
3139

40+
Random random = new Random(DEFAULT_RANDOM_SEED);
41+
3242
public Parameter(SolverType solver, double C, double eps) {
3343
setSolverType(solver);
3444
setC(C);
@@ -206,6 +216,10 @@ public boolean isRegularizeBias() {
206216
return regularize_bias;
207217
}
208218

219+
public void setRandom(Random random) {
220+
this.random = random;
221+
}
222+
209223
@Override
210224
public Parameter clone() {
211225
Parameter clone = new Parameter(solverType, C, eps, max_iters, p);
@@ -215,7 +229,23 @@ public Parameter clone() {
215229
clone.p = p;
216230
clone.nu = nu;
217231
clone.regularize_bias = regularize_bias;
232+
clone.random = deepClone(random);
218233
return clone;
219234
}
220235

236+
private static Random deepClone(Random random) {
237+
try {
238+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
239+
ObjectOutputStream out = new ObjectOutputStream(baos);) {
240+
out.writeObject(random);
241+
try (ByteArrayInputStream bis = new ByteArrayInputStream(baos.toByteArray());
242+
ObjectInputStream in = new ObjectInputStream(bis)) {
243+
return (Random)in.readObject();
244+
}
245+
}
246+
} catch (IOException | ClassNotFoundException e) {
247+
throw new RuntimeException("Failed to clone " + random, e);
248+
}
249+
}
250+
221251
}

src/main/java/de/bwaldvogel/liblinear/SolverMCSVM_CS.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static de.bwaldvogel.liblinear.Linear.*;
44

55
import java.util.Arrays;
6+
import java.util.Random;
67

78

89
/**
@@ -38,12 +39,14 @@ class SolverMCSVM_CS {
3839
private final int w_size, l;
3940
private final int nr_class;
4041
private final Problem prob;
42+
private final Random random;
4143

42-
public SolverMCSVM_CS(Problem prob, int nr_class, double[] C, double eps) {
44+
public SolverMCSVM_CS(Problem prob, int nr_class, double[] C, double eps, Random random) {
4345
this.w_size = prob.n;
4446
this.l = prob.l;
4547
this.nr_class = nr_class;
4648
this.eps = eps;
49+
this.random = random;
4750
this.max_iter = 100000;
4851
this.prob = prob;
4952
this.C = C;
@@ -115,7 +118,7 @@ public void solve(double[] w) {
115118

116119
for (i = 0; i < active_size; i++) {
117120
// int j = i+rand()%(active_size-i);
118-
int j = i + Linear.random.nextInt(active_size - i);
121+
int j = i + random.nextInt(active_size - i);
119122
swap(index, i, j);
120123
}
121124
for (s = 0; s < active_size; s++) {

src/test/java/de/bwaldvogel/liblinear/LinearTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class LinearTest {
3030

3131
@BeforeEach
3232
public void reset() throws Exception {
33-
Linear.resetRandom();
3433
Linear.disableDebugOutput();
3534
}
3635

src/test/java/de/bwaldvogel/liblinear/ParameterTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import static de.bwaldvogel.liblinear.SolverType.*;
44
import static org.assertj.core.api.Assertions.*;
55

6+
import java.util.Random;
7+
68
import org.junit.jupiter.api.BeforeEach;
79
import org.junit.jupiter.api.Test;
810

@@ -181,6 +183,8 @@ void testClone_Simple() throws Exception {
181183
void testClone_Full() throws Exception {
182184
Parameter parameter = new Parameter(L1R_LR, 123.456, 0.123, 9000, 1.2);
183185
parameter.setWeights(new double[] {1, 2}, new int[] {3, 4});
186+
Random random = new Random(123);
187+
parameter.setRandom(random);
184188
Parameter clone = parameter.clone();
185189
assertThat(clone.getSolverType()).isEqualTo(L1R_LR);
186190
assertThat(clone.getC()).isEqualTo(123.456);
@@ -190,6 +194,9 @@ void testClone_Full() throws Exception {
190194
assertThat(clone.getWeights()).containsExactly(1, 2);
191195
assertThat(clone.getWeightLabels()).containsExactly(3, 4);
192196
assertThat(clone.getNumWeights()).isEqualTo(2);
197+
198+
assertThat(clone.random).isNotSameAs(random);
199+
assertThat(random.nextInt()).isEqualTo(clone.random.nextInt());
193200
}
194201

195202
}

src/test/java/de/bwaldvogel/liblinear/PredictTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class PredictTest {
2828

2929
@BeforeEach
3030
public void setUp() {
31-
Linear.resetRandom();
3231
Linear.setDebugOutput(new PrintStream(byteArrayOutputStream));
3332
assertThat(testModel.getNrClass()).isGreaterThanOrEqualTo(2);
3433
assertThat(testModel.getNrFeature()).isGreaterThanOrEqualTo(10);

src/test/java/de/bwaldvogel/liblinear/RegressionTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ public String toString() {
210210
@MethodSource("data")
211211
void regressionTest(TestParams params) throws Exception {
212212
log.info("Running regression test for '{}'", params);
213-
Linear.resetRandom();
214213
Path trainingFile = Paths.get("src/test/datasets", params.dataset, params.dataset);
215214
Problem problem = Train.readProblem(trainingFile, params.bias);
216215
Parameter parameter = new Parameter(params.solverType, 1, 0.1);
@@ -292,7 +291,6 @@ void regressionTest(TestParams params) throws Exception {
292291

293292
@Test
294293
void testOneClass(@TempDir Path tempDir) throws Exception {
295-
Linear.resetRandom();
296294
Path trainingFile = Paths.get("src/test/datasets/splice/splice");
297295

298296
Path spliceClass1 = tempDir.resolve("splice-class-1");

src/test/java/de/bwaldvogel/liblinear/TrainTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class TrainTest {
2121

2222
@BeforeEach
2323
public void reset() throws Exception {
24-
Linear.resetRandom();
2524
Linear.disableDebugOutput();
2625
}
2726

0 commit comments

Comments
 (0)