3232import java .util .ArrayList ;
3333import java .util .Arrays ;
3434import java .util .List ;
35+ import java .util .Random ;
3536import java .util .stream .Collectors ;
3637import java .util .stream .Stream ;
3738
4445 */
4546public class EstimatorLayeredGraph extends SparsityEstimator {
4647
47- private static final int ROUNDS = 512 ;
48+ public static final int ROUNDS = 512 ;
4849 private final int _rounds ;
50+ private final Random _seeds ;
4951
5052 public EstimatorLayeredGraph () {
5153 this (ROUNDS );
5254 }
5355
5456 public EstimatorLayeredGraph (int rounds ) {
57+ this (rounds , (int )System .currentTimeMillis ());
58+ }
59+
60+ public EstimatorLayeredGraph (int rounds , int seed ) {
5561 _rounds = rounds ;
62+ _seeds = new Random (seed );
5663 }
5764
5865 @ Override
@@ -73,9 +80,9 @@ public LayeredGraph traverse(MMNode node) {
7380 LayeredGraph ret , left , right ;
7481
7582 left = (node .getLeft ().getData () == null )
76- ? retL : new LayeredGraph (node .getLeft ().getData (), _rounds );
83+ ? retL : new LayeredGraph (node .getLeft ().getData (), _rounds , _seeds . nextInt () );
7784 right = (node .getRight ().getData () == null )
78- ? retR : new LayeredGraph (node .getRight ().getData (), _rounds );
85+ ? retR : new LayeredGraph (node .getRight ().getData (), _rounds , _seeds . nextInt () );
7986
8087 ret = estimInternal (left , right , node .getOp ());
8188
@@ -86,24 +93,24 @@ public LayeredGraph traverse(MMNode node) {
8693 public double estim (MatrixBlock m1 , MatrixBlock m2 , OpCode op ) {
8794 if ( op == OpCode .MM )
8895 return estim (m1 , m2 );
89- LayeredGraph lg1 = new LayeredGraph (m1 , _rounds );
90- LayeredGraph lg2 = new LayeredGraph (m2 , _rounds );
96+ LayeredGraph lg1 = new LayeredGraph (m1 , _rounds , _seeds . nextInt () );
97+ LayeredGraph lg2 = new LayeredGraph (m2 , _rounds , _seeds . nextInt () );
9198 LayeredGraph output = estimInternal (lg1 , lg2 , op );
9299 return OptimizerUtils .getSparsity (
93100 output ._nodes .get (0 ).length , output ._nodes .get (output ._nodes .size () - 1 ).length , output .estimateNnz ());
94101 }
95102
96103 @ Override
97104 public double estim (MatrixBlock m , OpCode op ) {
98- LayeredGraph lg1 = new LayeredGraph (m , _rounds );
105+ LayeredGraph lg1 = new LayeredGraph (m , _rounds , _seeds . nextInt () );
99106 LayeredGraph output = estimInternal (lg1 , null , op );
100107 return OptimizerUtils .getSparsity (
101108 output ._nodes .get (0 ).length , output ._nodes .get (output ._nodes .size () - 1 ).length , output .estimateNnz ());
102109 }
103110
104111 @ Override
105112 public double estim (MatrixBlock m1 , MatrixBlock m2 ) {
106- LayeredGraph graph = new LayeredGraph (Arrays .asList (m1 ,m2 ), _rounds );
113+ LayeredGraph graph = new LayeredGraph (Arrays .asList (m1 ,m2 ), _rounds , _seeds . nextInt () );
107114 return OptimizerUtils .getSparsity (
108115 m1 .getNumRows (), m2 .getNumColumns (), graph .estimateNnz ());
109116 }
@@ -153,16 +160,21 @@ private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
153160 public static class LayeredGraph {
154161 private final List <Node []> _nodes ; //nodes partitioned by graph level
155162 private final int _rounds ; //length of propagated r-vectors
163+ private final Random _seeds ;
156164
157- public LayeredGraph (List < MatrixBlock > chain , int r ) {
165+ public LayeredGraph (int r , int seed ) {
158166 _nodes = new ArrayList <>();
159167 _rounds = r ;
168+ _seeds = new Random (seed );
169+ }
170+
171+ public LayeredGraph (List <MatrixBlock > chain , int r , int seed ) {
172+ this (r , seed );
160173 chain .forEach (i -> buildNext (i ));
161174 }
162175
163- public LayeredGraph (MatrixBlock m , int r ) {
164- _nodes = new ArrayList <>();
165- _rounds = r ;
176+ public LayeredGraph (MatrixBlock m , int r , int seed ) {
177+ this (r , seed );
166178 buildNext (m );
167179 }
168180
@@ -215,7 +227,8 @@ public void buildNext(MatrixBlock mb) {
215227 public long estimateNnz () {
216228 //step 1: assign random vectors ~exp(lambda=1) to all leaf nodes
217229 //(lambda is not the mean, if lambda is 2 mean is 1/2)
218- ExponentialDistribution random = new ExponentialDistribution (new Well1024a (), 1 );
230+ ExponentialDistribution random = new ExponentialDistribution (
231+ new Well1024a (_seeds .nextInt ()), 1 );
219232 for ( Node n : _nodes .get (0 ) ) {
220233 double [] rvect = new double [_rounds ];
221234 for (int g = 0 ; g < _rounds ; g ++)
@@ -234,7 +247,7 @@ private static double calcNNZ(double[] inpvec, int rounds) {
234247 }
235248
236249 public LayeredGraph rbind (LayeredGraph lg ) {
237- LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
250+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds , _seeds . nextInt () );
238251
239252 Node [] rows = new Node [_nodes .get (0 ).length + lg ._nodes .get (0 ).length ];
240253 Node [] columns = _nodes .get (1 ).clone ();
@@ -258,7 +271,7 @@ public LayeredGraph rbind(LayeredGraph lg) {
258271 }
259272
260273 public LayeredGraph cbind (LayeredGraph lg ) {
261- LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
274+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds , _seeds . nextInt () );
262275 int colLength = _nodes .get (1 ).length + lg ._nodes .get (1 ).length ;
263276
264277 Node [] rows = _nodes .get (0 ).clone ();
@@ -286,11 +299,11 @@ public LayeredGraph matMult(LayeredGraph lg) {
286299 List <MatrixBlock > m = Stream .concat (
287300 this .toMatrixBlockList ().stream (), lg .toMatrixBlockList ().stream ())
288301 .collect (Collectors .toList ());
289- return new LayeredGraph (m , _rounds );
302+ return new LayeredGraph (m , _rounds , _seeds . nextInt () );
290303 }
291304
292305 public LayeredGraph or (LayeredGraph lg ) {
293- LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
306+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds , _seeds . nextInt () );
294307 Node [] rows = new Node [_nodes .get (0 ).length ];
295308 for (int i = 0 ; i < _nodes .get (0 ).length ; i ++)
296309 rows [i ] = new Node ();
@@ -319,7 +332,7 @@ public LayeredGraph or(LayeredGraph lg) {
319332 }
320333
321334 public LayeredGraph and (LayeredGraph lg ) {
322- LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
335+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds , _seeds . nextInt () );
323336 Node [] rows = new Node [_nodes .get (0 ).length ];
324337 for (int i = 0 ; i < _nodes .get (0 ).length ; i ++)
325338 rows [i ] = new Node ();
@@ -348,7 +361,7 @@ public LayeredGraph and(LayeredGraph lg) {
348361 }
349362
350363 public LayeredGraph transpose () {
351- LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
364+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds , _seeds . nextInt () );
352365 Node [] rows = new Node [_nodes .get (_nodes .size () - 1 ).length ];
353366 for (int i = 0 ; i < rows .length ; i ++)
354367 rows [i ] = new Node ();
@@ -377,7 +390,7 @@ public LayeredGraph transpose() {
377390 }
378391
379392 public LayeredGraph diag () {
380- LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
393+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds , _seeds . nextInt () );
381394 Node [] rowsOld = _nodes .get (0 );
382395 Node [] columnsOld = _nodes .get (1 );
383396
0 commit comments