Skip to content

Commit e47e643

Browse files
committed
[MINOR] Fix proper seed handling in sparsity estimator LayeredGraph
In order to fix spurious test failures for specific seeds, we now fix the seeds in the respective tests, and implement proper seed handling in the sparsity estimator LayeredGraph.
1 parent 11c5a74 commit e47e643

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.ArrayList;
3333
import java.util.Arrays;
3434
import java.util.List;
35+
import java.util.Random;
3536
import java.util.stream.Collectors;
3637
import java.util.stream.Stream;
3738

@@ -44,15 +45,21 @@
4445
*/
4546
public class EstimatorLayeredGraph extends SparsityEstimator {
4647

47-
private static final int ROUNDS = 512;
48+
public static final int ROUNDS = 512;
4849
private final int _rounds;
50+
private final Random _seeds;
4951

5052
public EstimatorLayeredGraph() {
5153
this(ROUNDS);
5254
}
5355

5456
public EstimatorLayeredGraph(int rounds) {
57+
this(rounds, (int)System.currentTimeMillis());
58+
}
59+
60+
public EstimatorLayeredGraph(int rounds, int seed) {
5561
_rounds = rounds;
62+
_seeds = new Random(seed);
5663
}
5764

5865
@Override
@@ -73,9 +80,9 @@ public LayeredGraph traverse(MMNode node) {
7380
LayeredGraph ret, left, right;
7481

7582
left = (node.getLeft().getData() == null)
76-
? retL : new LayeredGraph(node.getLeft().getData(), _rounds);
83+
? retL : new LayeredGraph(node.getLeft().getData(), _rounds, _seeds.nextInt());
7784
right = (node.getRight().getData() == null)
78-
? retR : new LayeredGraph(node.getRight().getData(), _rounds);
85+
? retR : new LayeredGraph(node.getRight().getData(), _rounds, _seeds.nextInt());
7986

8087
ret = estimInternal(left, right, node.getOp());
8188

@@ -86,24 +93,24 @@ public LayeredGraph traverse(MMNode node) {
8693
public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
8794
if( op == OpCode.MM )
8895
return estim(m1, m2);
89-
LayeredGraph lg1 = new LayeredGraph(m1, _rounds);
90-
LayeredGraph lg2 = new LayeredGraph(m2, _rounds);
96+
LayeredGraph lg1 = new LayeredGraph(m1, _rounds, _seeds.nextInt());
97+
LayeredGraph lg2 = new LayeredGraph(m2, _rounds, _seeds.nextInt());
9198
LayeredGraph output = estimInternal(lg1, lg2, op);
9299
return OptimizerUtils.getSparsity(
93100
output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz());
94101
}
95102

96103
@Override
97104
public double estim(MatrixBlock m, OpCode op) {
98-
LayeredGraph lg1 = new LayeredGraph(m, _rounds);
105+
LayeredGraph lg1 = new LayeredGraph(m, _rounds, _seeds.nextInt());
99106
LayeredGraph output = estimInternal(lg1, null, op);
100107
return OptimizerUtils.getSparsity(
101108
output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz());
102109
}
103110

104111
@Override
105112
public double estim(MatrixBlock m1, MatrixBlock m2) {
106-
LayeredGraph graph = new LayeredGraph(Arrays.asList(m1,m2), _rounds);
113+
LayeredGraph graph = new LayeredGraph(Arrays.asList(m1,m2), _rounds, _seeds.nextInt());
107114
return OptimizerUtils.getSparsity(
108115
m1.getNumRows(), m2.getNumColumns(), graph.estimateNnz());
109116
}
@@ -153,16 +160,21 @@ private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
153160
public static class LayeredGraph {
154161
private final List<Node[]> _nodes; //nodes partitioned by graph level
155162
private final int _rounds; //length of propagated r-vectors
163+
private final Random _seeds;
156164

157-
public LayeredGraph(List<MatrixBlock> chain, int r) {
165+
public LayeredGraph(int r, int seed) {
158166
_nodes = new ArrayList<>();
159167
_rounds = r;
168+
_seeds = new Random(seed);
169+
}
170+
171+
public LayeredGraph(List<MatrixBlock> chain, int r, int seed) {
172+
this(r, seed);
160173
chain.forEach(i -> buildNext(i));
161174
}
162175

163-
public LayeredGraph(MatrixBlock m, int r) {
164-
_nodes = new ArrayList<>();
165-
_rounds = r;
176+
public LayeredGraph(MatrixBlock m, int r, int seed) {
177+
this(r, seed);
166178
buildNext(m);
167179
}
168180

@@ -215,7 +227,8 @@ public void buildNext(MatrixBlock mb) {
215227
public long estimateNnz() {
216228
//step 1: assign random vectors ~exp(lambda=1) to all leaf nodes
217229
//(lambda is not the mean, if lambda is 2 mean is 1/2)
218-
ExponentialDistribution random = new ExponentialDistribution(new Well1024a(), 1);
230+
ExponentialDistribution random = new ExponentialDistribution(
231+
new Well1024a(_seeds.nextInt()), 1);
219232
for( Node n : _nodes.get(0) ) {
220233
double[] rvect = new double[_rounds];
221234
for (int g = 0; g < _rounds; g++)
@@ -234,7 +247,7 @@ private static double calcNNZ(double[] inpvec, int rounds) {
234247
}
235248

236249
public LayeredGraph rbind(LayeredGraph lg) {
237-
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
250+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds, _seeds.nextInt());
238251

239252
Node[] rows = new Node[_nodes.get(0).length + lg._nodes.get(0).length];
240253
Node[] columns = _nodes.get(1).clone();
@@ -258,7 +271,7 @@ public LayeredGraph rbind(LayeredGraph lg) {
258271
}
259272

260273
public LayeredGraph cbind(LayeredGraph lg) {
261-
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
274+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds, _seeds.nextInt());
262275
int colLength = _nodes.get(1).length + lg._nodes.get(1).length;
263276

264277
Node[] rows = _nodes.get(0).clone();
@@ -286,11 +299,11 @@ public LayeredGraph matMult(LayeredGraph lg) {
286299
List<MatrixBlock> m = Stream.concat(
287300
this.toMatrixBlockList().stream(), lg.toMatrixBlockList().stream())
288301
.collect(Collectors.toList());
289-
return new LayeredGraph(m, _rounds);
302+
return new LayeredGraph(m, _rounds, _seeds.nextInt());
290303
}
291304

292305
public LayeredGraph or(LayeredGraph lg) {
293-
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
306+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds, _seeds.nextInt());
294307
Node[] rows = new Node[_nodes.get(0).length];
295308
for (int i = 0; i < _nodes.get(0).length; i++)
296309
rows[i] = new Node();
@@ -319,7 +332,7 @@ public LayeredGraph or(LayeredGraph lg) {
319332
}
320333

321334
public LayeredGraph and(LayeredGraph lg) {
322-
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
335+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds, _seeds.nextInt());
323336
Node[] rows = new Node[_nodes.get(0).length];
324337
for (int i = 0; i < _nodes.get(0).length; i++)
325338
rows[i] = new Node();
@@ -348,7 +361,7 @@ public LayeredGraph and(LayeredGraph lg) {
348361
}
349362

350363
public LayeredGraph transpose() {
351-
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
364+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds, _seeds.nextInt());
352365
Node[] rows = new Node[_nodes.get(_nodes.size() - 1).length];
353366
for (int i = 0; i < rows.length; i++)
354367
rows[i] = new Node();
@@ -377,7 +390,7 @@ public LayeredGraph transpose() {
377390
}
378391

379392
public LayeredGraph diag() {
380-
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
393+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds, _seeds.nextInt());
381394
Node[] rowsOld = _nodes.get(0);
382395
Node[] columnsOld = _nodes.get(1);
383396

src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,16 @@ public void testBitsetCasecbind() {
116116
//Layered Graph
117117
@Test
118118
public void testLGCaserbind() {
119-
runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, rbind);
119+
runSparsityEstimateTest(
120+
new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 7),
121+
m, k, n, sparsity, rbind);
120122
}
121123

122124
@Test
123125
public void testLGCasecbind() {
124-
runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, cbind);
126+
runSparsityEstimateTest(
127+
new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 3),
128+
m, k, n, sparsity, cbind);
125129
}
126130

127131

0 commit comments

Comments
 (0)