Skip to content

Commit f86106c

Browse files
MRGSRT authored and mboehm7 committed
[SYSTEMDS-2285] Extended Sparsity Estimation Tests and Baselines
Closes #1945.
1 parent e1377d2 commit f86106c

File tree

11 files changed

+444
-48
lines changed

11 files changed

+444
-48
lines changed

src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java

Lines changed: 328 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
import java.util.Arrays;
3434
import java.util.List;
3535
import java.util.stream.Collectors;
36+
import java.util.stream.Stream;
3637

3738
/**
3839
* This estimator implements an approach based on a so-called layered graph,
@@ -43,7 +44,7 @@
4344
*/
4445
public class EstimatorLayeredGraph extends SparsityEstimator {
4546

46-
private static final int ROUNDS = 32;
47+
private static final int ROUNDS = 512;
4748
private final int _rounds;
4849

4950
public EstimatorLayeredGraph() {
@@ -57,21 +58,47 @@ public EstimatorLayeredGraph(int rounds) {
5758
@Override
5859
public DataCharacteristics estim(MMNode root) {
5960
List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>());
60-
long nnz = new LayeredGraph(leafs, _rounds).estimateNnz();
61+
List<OpCode> ops = getOps(root, new ArrayList<>());
62+
List<LayeredGraph> LGs = new ArrayList<>();
63+
LayeredGraph ret = traverse(root);
64+
long nnz = ret.estimateNnz();
6165
return root.setDataCharacteristics(new MatrixCharacteristics(
62-
leafs.get(0).getNumRows(), leafs.get(leafs.size()-1).getNumColumns(), nnz));
66+
ret._nodes.get(0).length, ret._nodes.get(ret._nodes.size() - 1).length, nnz));
67+
}
68+
69+
public LayeredGraph traverse(MMNode node) {
70+
if(node.getLeft() == null || node.getRight() == null) return null;
71+
LayeredGraph retL = traverse(node.getLeft());
72+
LayeredGraph retR = traverse(node.getRight());
73+
LayeredGraph ret, left, right;
74+
75+
left = (node.getLeft().getData() == null)
76+
? retL : new LayeredGraph(node.getLeft().getData(), _rounds);
77+
right = (node.getRight().getData() == null)
78+
? retR : new LayeredGraph(node.getRight().getData(), _rounds);
79+
80+
ret = estimInternal(left, right, node.getOp());
81+
82+
return ret;
6383
}
6484

6585
@Override
6686
public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
6787
if( op == OpCode.MM )
6888
return estim(m1, m2);
69-
throw new NotImplementedException();
89+
LayeredGraph lg1 = new LayeredGraph(m1, _rounds);
90+
LayeredGraph lg2 = new LayeredGraph(m2, _rounds);
91+
LayeredGraph output = estimInternal(lg1, lg2, op);
92+
return OptimizerUtils.getSparsity(
93+
output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz());
7094
}
7195

7296
@Override
7397
public double estim(MatrixBlock m, OpCode op) {
74-
throw new NotImplementedException();
98+
LayeredGraph lg1 = new LayeredGraph(m, _rounds);
99+
LayeredGraph output = estimInternal(lg1, null, op);
100+
return OptimizerUtils.getSparsity(
101+
output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz());
75102
}
76103

77104
@Override
@@ -80,6 +107,23 @@ public double estim(MatrixBlock m1, MatrixBlock m2) {
80107
return OptimizerUtils.getSparsity(
81108
m1.getNumRows(), m2.getNumColumns(), graph.estimateNnz());
82109
}
110+
111+
private static LayeredGraph estimInternal(LayeredGraph lg1, LayeredGraph lg2, OpCode op) {
112+
switch(op) {
113+
case MM: return lg1.matMult(lg2);
114+
case MULT: return lg1.and(lg2);
115+
case PLUS: return lg1.or(lg2);
116+
case RBIND: return lg1.rbind(lg2);
117+
case CBIND: return lg1.cbind(lg2);
118+
// case NEQZERO:
119+
// case EQZERO:
120+
case TRANS: return lg1.transpose();
121+
case DIAG: return lg1.diag();
122+
// case RESHAPE:
123+
default:
124+
throw new NotImplementedException();
125+
}
126+
}
83127

84128
private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
85129
//NOTE: this extraction is only correct and efficient for chains, no DAGs
@@ -92,6 +136,18 @@ private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
92136
return leafs;
93137
}
94138

139+
private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
140+
//NOTE: this extraction is only correct and efficient for chains, no DAGs
141+
if(node.isLeaf()) {
142+
}
143+
else {
144+
getOps(node.getLeft(), ops);
145+
getOps(node.getRight(), ops);
146+
ops.add(node.getOp());
147+
}
148+
return ops;
149+
}
150+
95151
public static class LayeredGraph {
96152
private final List<Node[]> _nodes; //nodes partitioned by graph level
97153
private final int _rounds; //length of propagated r-vectors
@@ -101,6 +157,12 @@ public LayeredGraph(List<MatrixBlock> chain, int r) {
101157
_rounds = r;
102158
chain.forEach(i -> buildNext(i));
103159
}
160+
161+
/**
 * Creates a layered graph for a single matrix.
 *
 * @param m matrix handed to {@code buildNext}
 * @param r length of the propagated r-vectors
 */
public LayeredGraph(MatrixBlock m, int r) {
	this._rounds = r;
	this._nodes = new ArrayList<>();
	this.buildNext(m);
}
104166

105167
public void buildNext(MatrixBlock mb) {
106168
if( mb.isEmpty() )
@@ -168,7 +230,267 @@ private static double calcNNZ(double[] inpvec, int rounds) {
168230
return (inpvec != null && inpvec.length > 0) ?
169231
(rounds - 1) / Arrays.stream(inpvec).sum() : 0;
170232
}
171-
233+
234+
public LayeredGraph rbind(LayeredGraph lg) {
235+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
236+
237+
Node[] rows = new Node[_nodes.get(0).length + lg._nodes.get(0).length];
238+
Node[] columns = _nodes.get(1).clone();
239+
240+
System.arraycopy(_nodes.get(0), 0, rows, 0, _nodes.get(0).length);
241+
242+
for (int i = _nodes.get(0).length; i < rows.length; i++)
243+
rows[i] = new Node();
244+
245+
for(int i = 0; i < lg._nodes.get(0).length; i++) {
246+
for(int j = 0; j < columns.length; j++) {
247+
List<Node> edges = lg._nodes.get(1)[j].getInput();
248+
if(edges.contains(lg._nodes.get(0)[i])) {
249+
columns[j].addInput(rows[i + _nodes.get(0).length]);
250+
}
251+
}
252+
}
253+
ret._nodes.add(rows);
254+
ret._nodes.add(columns);
255+
return ret;
256+
}
257+
258+
public LayeredGraph cbind(LayeredGraph lg) {
259+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
260+
int colLength = _nodes.get(1).length + lg._nodes.get(1).length;
261+
262+
Node[] rows = _nodes.get(0).clone();
263+
Node[] columns = new Node[colLength];
264+
265+
System.arraycopy(_nodes.get(1), 0, columns, 0, _nodes.get(1).length);
266+
267+
for (int i = _nodes.get(1).length; i < columns.length; i++)
268+
columns[i] = new Node();
269+
270+
for(int i = 0; i < rows.length; i++) {
271+
for(int j = 0; j < lg._nodes.get(1).length; j++) {
272+
List<Node> edges = lg._nodes.get(1)[j].getInput();
273+
if(edges.contains(lg._nodes.get(0)[i])) {
274+
columns[j + _nodes.get(1).length].addInput(rows[i]);
275+
}
276+
}
277+
}
278+
ret._nodes.add(rows);
279+
ret._nodes.add(columns);
280+
return ret;
281+
}
282+
283+
public LayeredGraph matMult(LayeredGraph lg) {
284+
List<MatrixBlock> m = Stream.concat(
285+
this.toMatrixBlockList().stream(), lg.toMatrixBlockList().stream())
286+
.collect(Collectors.toList());
287+
return new LayeredGraph(m, _rounds);
288+
}
289+
290+
public LayeredGraph or(LayeredGraph lg) {
291+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
292+
Node[] rows = new Node[_nodes.get(0).length];
293+
for (int i = 0; i < _nodes.get(0).length; i++)
294+
rows[i] = new Node();
295+
ret._nodes.add(rows);
296+
297+
for(int x = 0; x < _nodes.size() - 1; x++) {
298+
int y = x + 1;
299+
rows = ret._nodes.get(x);
300+
Node[] columns = new Node[_nodes.get(y).length];
301+
for (int i = 0; i < _nodes.get(y).length; i++)
302+
columns[i] = new Node();
303+
304+
for(int i = 0; i < _nodes.get(x).length; i++) {
305+
for(int j = 0; j < _nodes.get(y).length; j++) {
306+
List<Node> edges1 = _nodes.get(y)[j].getInput();
307+
List<Node> edges2 = lg._nodes.get(y)[j].getInput();
308+
if(edges1.contains(_nodes.get(x)[i]) || edges2.contains(lg._nodes.get(x)[i]))
309+
{
310+
columns[j].addInput(rows[i]);
311+
}
312+
}
313+
}
314+
ret._nodes.add(columns);
315+
}
316+
return ret;
317+
}
318+
319+
public LayeredGraph and(LayeredGraph lg) {
320+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
321+
Node[] rows = new Node[_nodes.get(0).length];
322+
for (int i = 0; i < _nodes.get(0).length; i++)
323+
rows[i] = new Node();
324+
ret._nodes.add(rows);
325+
326+
for(int x = 0; x < _nodes.size() - 1; x++) {
327+
int y = x + 1;
328+
rows = ret._nodes.get(x);
329+
Node[] columns = new Node[_nodes.get(y).length];
330+
for (int i = 0; i < _nodes.get(y).length; i++)
331+
columns[i] = new Node();
332+
333+
for(int i = 0; i < _nodes.get(x).length; i++) {
334+
for(int j = 0; j < _nodes.get(y).length; j++) {
335+
List<Node> edges1 = _nodes.get(y)[j].getInput();
336+
List<Node> edges2 = lg._nodes.get(y)[j].getInput();
337+
if(edges1.contains(_nodes.get(x)[i]) && edges2.contains(lg._nodes.get(x)[i]))
338+
{
339+
columns[j].addInput(rows[i]);
340+
}
341+
}
342+
}
343+
ret._nodes.add(columns);
344+
}
345+
return ret;
346+
}
347+
348+
public LayeredGraph transpose() {
349+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
350+
Node[] rows = new Node[_nodes.get(_nodes.size() - 1).length];
351+
for (int i = 0; i < rows.length; i++)
352+
rows[i] = new Node();
353+
ret._nodes.add(rows);
354+
355+
for(int x = _nodes.size() - 1; x > 0; x--) {
356+
rows = ret._nodes.get(ret._nodes.size() - 1);
357+
Node[] columnsOld = _nodes.get(x);
358+
Node[] rowsOld = _nodes.get(x - 1);
359+
Node[] columns = new Node[rowsOld.length];
360+
361+
for (int i = 0; i < rowsOld.length; i++)
362+
columns[i] = new Node();
363+
364+
for(int i = 0; i < rowsOld.length; i++) {
365+
for(int j = 0; j < columnsOld.length; j++) {
366+
List<Node> edges = columnsOld[j].getInput();
367+
if(edges.contains(rowsOld[i])) {
368+
columns[i].addInput(rows[j]);
369+
}
370+
}
371+
}
372+
ret._nodes.add(columns);
373+
}
374+
return ret;
375+
}
376+
377+
public LayeredGraph diag() {
378+
LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
379+
Node[] rowsOld = _nodes.get(0);
380+
Node[] columnsOld = _nodes.get(1);
381+
382+
if(_nodes.get(1).length == 1) {
383+
Node[] rows = new Node[rowsOld.length];
384+
Node[] columns = new Node[rowsOld.length];
385+
386+
for (int i = 0; i < rowsOld.length; i++)
387+
rows[i] = new Node();
388+
for (int i = 0; i < rowsOld.length; i++)
389+
columns[i] = new Node();
390+
391+
List<Node> edges = columnsOld[0].getInput();
392+
for(int i = 0; i < rowsOld.length; i++) {
393+
for(int j = 0; j < rowsOld.length; j++) {
394+
if(edges.contains(rowsOld[i]) && i == j) {
395+
columns[j].addInput(rows[i]);
396+
}
397+
}
398+
}
399+
ret._nodes.add(rows);
400+
ret._nodes.add(columns);
401+
return ret;
402+
}
403+
else if(_nodes.get(0).length == 1){
404+
Node[] rows = new Node[columnsOld.length];
405+
Node[] columns = new Node[columnsOld.length];
406+
407+
for (int i = 0; i < columnsOld.length; i++)
408+
rows[i] = new Node();
409+
for (int i = 0; i < columnsOld.length; i++)
410+
columns[i] = new Node();
411+
412+
for(int i = 0; i < columnsOld.length; i++) {
413+
for(int j = 0; j < columnsOld.length; j++) {
414+
List<Node> edges = columnsOld[j].getInput();
415+
if(edges.contains(rowsOld[0]) && i == j) {
416+
columns[j].addInput(rows[i]);
417+
}
418+
}
419+
}
420+
ret._nodes.add(rows);
421+
ret._nodes.add(columns);
422+
return ret;
423+
}
424+
else {
425+
Node[] rows = new Node[rowsOld.length];
426+
Node[] columns = new Node[1];
427+
for (int i = 0; i < rowsOld.length; i++)
428+
rows[i] = new Node();
429+
for (int i = 0; i < 1; i++)
430+
columns[i] = new Node();
431+
for(int i = 0; i < rowsOld.length; i++) {
432+
for(int j = 0; j < columnsOld.length; j++) {
433+
List<Node> edges = columnsOld[j].getInput();
434+
if(edges.contains(rowsOld[i]) && i == j) {
435+
columns[0].addInput(rows[i]);
436+
}
437+
}
438+
}
439+
ret._nodes.add(rows);
440+
ret._nodes.add(columns);
441+
return ret;
442+
}
443+
}
444+
445+
public MatrixBlock toMatrixBlock() {
446+
List<Double> a = new ArrayList<>();
447+
int rows = _nodes.get(0).length;
448+
int cols = _nodes.get(1).length;
449+
for(int i = 0; i < rows * cols; i++) {
450+
a.add(0.);
451+
}
452+
for(int i = 0; i < rows; i++) {
453+
for(int j = 0; j < cols; j++) {
454+
List<Node> edges = _nodes.get(1)[j].getInput();
455+
if(edges.contains(_nodes.get(0)[i])) {
456+
a.set(i * cols + j, 1. + a.get(i * cols + j));
457+
}
458+
else {
459+
a.set(i * cols + j, 0.);
460+
}
461+
}
462+
}
463+
double[] arr = a.stream().mapToDouble(d -> d).toArray();
464+
return new MatrixBlock(rows, cols, arr);
465+
}
466+
467+
public List<MatrixBlock> toMatrixBlockList() {
468+
List<MatrixBlock> m = new ArrayList<>();
469+
for(int x = 0; x < _nodes.size() - 1; x++) {
470+
int y = x + 1;
471+
List<Double> a = new ArrayList<>();
472+
int rows = _nodes.get(x).length;
473+
int cols = _nodes.get(y).length;
474+
for(int i = 0; i < rows * cols; i++) {
475+
a.add(0.);
476+
}
477+
for(int i = 0; i < rows; i++) {
478+
for(int j = 0; j < cols; j++) {
479+
List<Node> edges = _nodes.get(y)[j].getInput();
480+
if(edges.contains(_nodes.get(x)[i])) {
481+
a.set(i * cols + j, 1. + a.get(i * cols + j));
482+
}
483+
else {
484+
a.set(i * cols + j, 0.);
485+
}
486+
}
487+
}
488+
double[] arr = a.stream().mapToDouble(d -> d).toArray();
489+
m.add(new MatrixBlock(rows, cols, arr));
490+
}
491+
return m;
492+
}
493+
172494
private static class Node {
173495
private List<Node> _input = new ArrayList<>();
174496
private double[] _rvect;

0 commit comments

Comments
 (0)