3333import java .util .Arrays ;
3434import java .util .List ;
3535import java .util .stream .Collectors ;
36+ import java .util .stream .Stream ;
3637
3738/**
3839 * This estimator implements an approach based on a so-called layered graph,
4344 */
4445public class EstimatorLayeredGraph extends SparsityEstimator {
4546
46- private static final int ROUNDS = 32 ;
47+ private static final int ROUNDS = 512 ;
4748 private final int _rounds ;
4849
4950 public EstimatorLayeredGraph () {
@@ -57,21 +58,47 @@ public EstimatorLayeredGraph(int rounds) {
5758 @ Override
5859 public DataCharacteristics estim (MMNode root ) {
5960 List <MatrixBlock > leafs = getMatrices (root , new ArrayList <>());
60- long nnz = new LayeredGraph (leafs , _rounds ).estimateNnz ();
61+ List <OpCode > ops = getOps (root , new ArrayList <>());
62+ List <LayeredGraph > LGs = new ArrayList <>();
63+ LayeredGraph ret = traverse (root );
64+ long nnz = ret .estimateNnz ();
6165 return root .setDataCharacteristics (new MatrixCharacteristics (
62- leafs .get (0 ).getNumRows (), leafs .get (leafs .size ()-1 ).getNumColumns (), nnz ));
66+ ret ._nodes .get (0 ).length , ret ._nodes .get (ret ._nodes .size () - 1 ).length , nnz ));
67+ }
68+
69+ public LayeredGraph traverse (MMNode node ) {
70+ if (node .getLeft () == null || node .getRight () == null ) return null ;
71+ LayeredGraph retL = traverse (node .getLeft ());
72+ LayeredGraph retR = traverse (node .getRight ());
73+ LayeredGraph ret , left , right ;
74+
75+ left = (node .getLeft ().getData () == null )
76+ ? retL : new LayeredGraph (node .getLeft ().getData (), _rounds );
77+ right = (node .getRight ().getData () == null )
78+ ? retR : new LayeredGraph (node .getRight ().getData (), _rounds );
79+
80+ ret = estimInternal (left , right , node .getOp ());
81+
82+ return ret ;
6383 }
6484
6585 @ Override
6686 public double estim (MatrixBlock m1 , MatrixBlock m2 , OpCode op ) {
6787 if ( op == OpCode .MM )
6888 return estim (m1 , m2 );
69- throw new NotImplementedException ();
89+ LayeredGraph lg1 = new LayeredGraph (m1 , _rounds );
90+ LayeredGraph lg2 = new LayeredGraph (m2 , _rounds );
91+ LayeredGraph output = estimInternal (lg1 , lg2 , op );
92+ return OptimizerUtils .getSparsity (
93+ output ._nodes .get (0 ).length , output ._nodes .get (output ._nodes .size () - 1 ).length , output .estimateNnz ());
7094 }
7195
7296 @ Override
7397 public double estim (MatrixBlock m , OpCode op ) {
74- throw new NotImplementedException ();
98+ LayeredGraph lg1 = new LayeredGraph (m , _rounds );
99+ LayeredGraph output = estimInternal (lg1 , null , op );
100+ return OptimizerUtils .getSparsity (
101+ output ._nodes .get (0 ).length , output ._nodes .get (output ._nodes .size () - 1 ).length , output .estimateNnz ());
75102 }
76103
77104 @ Override
@@ -80,6 +107,23 @@ public double estim(MatrixBlock m1, MatrixBlock m2) {
80107 return OptimizerUtils .getSparsity (
81108 m1 .getNumRows (), m2 .getNumColumns (), graph .estimateNnz ());
82109 }
110+
111+ private static LayeredGraph estimInternal (LayeredGraph lg1 , LayeredGraph lg2 , OpCode op ) {
112+ switch (op ) {
113+ case MM : return lg1 .matMult (lg2 );
114+ case MULT : return lg1 .and (lg2 );
115+ case PLUS : return lg1 .or (lg2 );
116+ case RBIND : return lg1 .rbind (lg2 );
117+ case CBIND : return lg1 .cbind (lg2 );
118+ // case NEQZERO:
119+ // case EQZERO:
120+ case TRANS : return lg1 .transpose ();
121+ case DIAG : return lg1 .diag ();
122+ // case RESHAPE:
123+ default :
124+ throw new NotImplementedException ();
125+ }
126+ }
83127
84128 private List <MatrixBlock > getMatrices (MMNode node , List <MatrixBlock > leafs ) {
85129 //NOTE: this extraction is only correct and efficient for chains, no DAGs
@@ -92,6 +136,18 @@ private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
92136 return leafs ;
93137 }
94138
139+ private List <OpCode > getOps (MMNode node , List <OpCode > ops ) {
140+ //NOTE: this extraction is only correct and efficient for chains, no DAGs
141+ if (node .isLeaf ()) {
142+ }
143+ else {
144+ getOps (node .getLeft (), ops );
145+ getOps (node .getRight (), ops );
146+ ops .add (node .getOp ());
147+ }
148+ return ops ;
149+ }
150+
95151 public static class LayeredGraph {
96152 private final List <Node []> _nodes ; //nodes partitioned by graph level
97153 private final int _rounds ; //length of propagated r-vectors
@@ -101,6 +157,12 @@ public LayeredGraph(List<MatrixBlock> chain, int r) {
101157 _rounds = r ;
102158 chain .forEach (i -> buildNext (i ));
103159 }
160+
/**
 * Creates a layered graph from a single matrix.
 *
 * @param m matrix handed to {@code buildNext}
 * @param r length of the propagated r-vectors
 */
public LayeredGraph(MatrixBlock m, int r) {
	this._rounds = r;
	this._nodes = new ArrayList<>();
	this.buildNext(m);
}
104166
105167 public void buildNext (MatrixBlock mb ) {
106168 if ( mb .isEmpty () )
@@ -168,7 +230,267 @@ private static double calcNNZ(double[] inpvec, int rounds) {
168230 return (inpvec != null && inpvec .length > 0 ) ?
169231 (rounds - 1 ) / Arrays .stream (inpvec ).sum () : 0 ;
170232 }
171-
233+
234+ public LayeredGraph rbind (LayeredGraph lg ) {
235+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
236+
237+ Node [] rows = new Node [_nodes .get (0 ).length + lg ._nodes .get (0 ).length ];
238+ Node [] columns = _nodes .get (1 ).clone ();
239+
240+ System .arraycopy (_nodes .get (0 ), 0 , rows , 0 , _nodes .get (0 ).length );
241+
242+ for (int i = _nodes .get (0 ).length ; i < rows .length ; i ++)
243+ rows [i ] = new Node ();
244+
245+ for (int i = 0 ; i < lg ._nodes .get (0 ).length ; i ++) {
246+ for (int j = 0 ; j < columns .length ; j ++) {
247+ List <Node > edges = lg ._nodes .get (1 )[j ].getInput ();
248+ if (edges .contains (lg ._nodes .get (0 )[i ])) {
249+ columns [j ].addInput (rows [i + _nodes .get (0 ).length ]);
250+ }
251+ }
252+ }
253+ ret ._nodes .add (rows );
254+ ret ._nodes .add (columns );
255+ return ret ;
256+ }
257+
258+ public LayeredGraph cbind (LayeredGraph lg ) {
259+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
260+ int colLength = _nodes .get (1 ).length + lg ._nodes .get (1 ).length ;
261+
262+ Node [] rows = _nodes .get (0 ).clone ();
263+ Node [] columns = new Node [colLength ];
264+
265+ System .arraycopy (_nodes .get (1 ), 0 , columns , 0 , _nodes .get (1 ).length );
266+
267+ for (int i = _nodes .get (1 ).length ; i < columns .length ; i ++)
268+ columns [i ] = new Node ();
269+
270+ for (int i = 0 ; i < rows .length ; i ++) {
271+ for (int j = 0 ; j < lg ._nodes .get (1 ).length ; j ++) {
272+ List <Node > edges = lg ._nodes .get (1 )[j ].getInput ();
273+ if (edges .contains (lg ._nodes .get (0 )[i ])) {
274+ columns [j + _nodes .get (1 ).length ].addInput (rows [i ]);
275+ }
276+ }
277+ }
278+ ret ._nodes .add (rows );
279+ ret ._nodes .add (columns );
280+ return ret ;
281+ }
282+
283+ public LayeredGraph matMult (LayeredGraph lg ) {
284+ List <MatrixBlock > m = Stream .concat (
285+ this .toMatrixBlockList ().stream (), lg .toMatrixBlockList ().stream ())
286+ .collect (Collectors .toList ());
287+ return new LayeredGraph (m , _rounds );
288+ }
289+
290+ public LayeredGraph or (LayeredGraph lg ) {
291+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
292+ Node [] rows = new Node [_nodes .get (0 ).length ];
293+ for (int i = 0 ; i < _nodes .get (0 ).length ; i ++)
294+ rows [i ] = new Node ();
295+ ret ._nodes .add (rows );
296+
297+ for (int x = 0 ; x < _nodes .size () - 1 ; x ++) {
298+ int y = x + 1 ;
299+ rows = ret ._nodes .get (x );
300+ Node [] columns = new Node [_nodes .get (y ).length ];
301+ for (int i = 0 ; i < _nodes .get (y ).length ; i ++)
302+ columns [i ] = new Node ();
303+
304+ for (int i = 0 ; i < _nodes .get (x ).length ; i ++) {
305+ for (int j = 0 ; j < _nodes .get (y ).length ; j ++) {
306+ List <Node > edges1 = _nodes .get (y )[j ].getInput ();
307+ List <Node > edges2 = lg ._nodes .get (y )[j ].getInput ();
308+ if (edges1 .contains (_nodes .get (x )[i ]) || edges2 .contains (lg ._nodes .get (x )[i ]))
309+ {
310+ columns [j ].addInput (rows [i ]);
311+ }
312+ }
313+ }
314+ ret ._nodes .add (columns );
315+ }
316+ return ret ;
317+ }
318+
319+ public LayeredGraph and (LayeredGraph lg ) {
320+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
321+ Node [] rows = new Node [_nodes .get (0 ).length ];
322+ for (int i = 0 ; i < _nodes .get (0 ).length ; i ++)
323+ rows [i ] = new Node ();
324+ ret ._nodes .add (rows );
325+
326+ for (int x = 0 ; x < _nodes .size () - 1 ; x ++) {
327+ int y = x + 1 ;
328+ rows = ret ._nodes .get (x );
329+ Node [] columns = new Node [_nodes .get (y ).length ];
330+ for (int i = 0 ; i < _nodes .get (y ).length ; i ++)
331+ columns [i ] = new Node ();
332+
333+ for (int i = 0 ; i < _nodes .get (x ).length ; i ++) {
334+ for (int j = 0 ; j < _nodes .get (y ).length ; j ++) {
335+ List <Node > edges1 = _nodes .get (y )[j ].getInput ();
336+ List <Node > edges2 = lg ._nodes .get (y )[j ].getInput ();
337+ if (edges1 .contains (_nodes .get (x )[i ]) && edges2 .contains (lg ._nodes .get (x )[i ]))
338+ {
339+ columns [j ].addInput (rows [i ]);
340+ }
341+ }
342+ }
343+ ret ._nodes .add (columns );
344+ }
345+ return ret ;
346+ }
347+
348+ public LayeredGraph transpose () {
349+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
350+ Node [] rows = new Node [_nodes .get (_nodes .size () - 1 ).length ];
351+ for (int i = 0 ; i < rows .length ; i ++)
352+ rows [i ] = new Node ();
353+ ret ._nodes .add (rows );
354+
355+ for (int x = _nodes .size () - 1 ; x > 0 ; x --) {
356+ rows = ret ._nodes .get (ret ._nodes .size () - 1 );
357+ Node [] columnsOld = _nodes .get (x );
358+ Node [] rowsOld = _nodes .get (x - 1 );
359+ Node [] columns = new Node [rowsOld .length ];
360+
361+ for (int i = 0 ; i < rowsOld .length ; i ++)
362+ columns [i ] = new Node ();
363+
364+ for (int i = 0 ; i < rowsOld .length ; i ++) {
365+ for (int j = 0 ; j < columnsOld .length ; j ++) {
366+ List <Node > edges = columnsOld [j ].getInput ();
367+ if (edges .contains (rowsOld [i ])) {
368+ columns [i ].addInput (rows [j ]);
369+ }
370+ }
371+ }
372+ ret ._nodes .add (columns );
373+ }
374+ return ret ;
375+ }
376+
377+ public LayeredGraph diag () {
378+ LayeredGraph ret = new LayeredGraph (List .of (), _rounds );
379+ Node [] rowsOld = _nodes .get (0 );
380+ Node [] columnsOld = _nodes .get (1 );
381+
382+ if (_nodes .get (1 ).length == 1 ) {
383+ Node [] rows = new Node [rowsOld .length ];
384+ Node [] columns = new Node [rowsOld .length ];
385+
386+ for (int i = 0 ; i < rowsOld .length ; i ++)
387+ rows [i ] = new Node ();
388+ for (int i = 0 ; i < rowsOld .length ; i ++)
389+ columns [i ] = new Node ();
390+
391+ List <Node > edges = columnsOld [0 ].getInput ();
392+ for (int i = 0 ; i < rowsOld .length ; i ++) {
393+ for (int j = 0 ; j < rowsOld .length ; j ++) {
394+ if (edges .contains (rowsOld [i ]) && i == j ) {
395+ columns [j ].addInput (rows [i ]);
396+ }
397+ }
398+ }
399+ ret ._nodes .add (rows );
400+ ret ._nodes .add (columns );
401+ return ret ;
402+ }
403+ else if (_nodes .get (0 ).length == 1 ){
404+ Node [] rows = new Node [columnsOld .length ];
405+ Node [] columns = new Node [columnsOld .length ];
406+
407+ for (int i = 0 ; i < columnsOld .length ; i ++)
408+ rows [i ] = new Node ();
409+ for (int i = 0 ; i < columnsOld .length ; i ++)
410+ columns [i ] = new Node ();
411+
412+ for (int i = 0 ; i < columnsOld .length ; i ++) {
413+ for (int j = 0 ; j < columnsOld .length ; j ++) {
414+ List <Node > edges = columnsOld [j ].getInput ();
415+ if (edges .contains (rowsOld [0 ]) && i == j ) {
416+ columns [j ].addInput (rows [i ]);
417+ }
418+ }
419+ }
420+ ret ._nodes .add (rows );
421+ ret ._nodes .add (columns );
422+ return ret ;
423+ }
424+ else {
425+ Node [] rows = new Node [rowsOld .length ];
426+ Node [] columns = new Node [1 ];
427+ for (int i = 0 ; i < rowsOld .length ; i ++)
428+ rows [i ] = new Node ();
429+ for (int i = 0 ; i < 1 ; i ++)
430+ columns [i ] = new Node ();
431+ for (int i = 0 ; i < rowsOld .length ; i ++) {
432+ for (int j = 0 ; j < columnsOld .length ; j ++) {
433+ List <Node > edges = columnsOld [j ].getInput ();
434+ if (edges .contains (rowsOld [i ]) && i == j ) {
435+ columns [0 ].addInput (rows [i ]);
436+ }
437+ }
438+ }
439+ ret ._nodes .add (rows );
440+ ret ._nodes .add (columns );
441+ return ret ;
442+ }
443+ }
444+
445+ public MatrixBlock toMatrixBlock () {
446+ List <Double > a = new ArrayList <>();
447+ int rows = _nodes .get (0 ).length ;
448+ int cols = _nodes .get (1 ).length ;
449+ for (int i = 0 ; i < rows * cols ; i ++) {
450+ a .add (0. );
451+ }
452+ for (int i = 0 ; i < rows ; i ++) {
453+ for (int j = 0 ; j < cols ; j ++) {
454+ List <Node > edges = _nodes .get (1 )[j ].getInput ();
455+ if (edges .contains (_nodes .get (0 )[i ])) {
456+ a .set (i * cols + j , 1. + a .get (i * cols + j ));
457+ }
458+ else {
459+ a .set (i * cols + j , 0. );
460+ }
461+ }
462+ }
463+ double [] arr = a .stream ().mapToDouble (d -> d ).toArray ();
464+ return new MatrixBlock (rows , cols , arr );
465+ }
466+
467+ public List <MatrixBlock > toMatrixBlockList () {
468+ List <MatrixBlock > m = new ArrayList <>();
469+ for (int x = 0 ; x < _nodes .size () - 1 ; x ++) {
470+ int y = x + 1 ;
471+ List <Double > a = new ArrayList <>();
472+ int rows = _nodes .get (x ).length ;
473+ int cols = _nodes .get (y ).length ;
474+ for (int i = 0 ; i < rows * cols ; i ++) {
475+ a .add (0. );
476+ }
477+ for (int i = 0 ; i < rows ; i ++) {
478+ for (int j = 0 ; j < cols ; j ++) {
479+ List <Node > edges = _nodes .get (y )[j ].getInput ();
480+ if (edges .contains (_nodes .get (x )[i ])) {
481+ a .set (i * cols + j , 1. + a .get (i * cols + j ));
482+ }
483+ else {
484+ a .set (i * cols + j , 0. );
485+ }
486+ }
487+ }
488+ double [] arr = a .stream ().mapToDouble (d -> d ).toArray ();
489+ m .add (new MatrixBlock (rows , cols , arr ));
490+ }
491+ return m ;
492+ }
493+
172494 private static class Node {
173495 private List <Node > _input = new ArrayList <>();
174496 private double [] _rvect ;
0 commit comments