Skip to content

Commit 1a6becf

Browse files
clean tests and improve explain prints
1 parent 911c796 commit 1a6becf

File tree

3 files changed

+57
-92
lines changed

3 files changed

+57
-92
lines changed

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ public String[] recursivePrintString() {
112112
String[] res = new String[left.length + right.length+1];
113113
res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString();
114114
for (int i=0; i<left.length; i++) {
115-
res[i+1] = (i==0 ? "┌─ " : "| ") +left[i];
115+
res[i+1] = (i==0 ? "┌─ " : " ") +left[i];
116116
}
117117
for (int i=0; i<right.length; i++) {
118-
res[left.length+i+1] = (i==0 ? "└─ " : "| ") +right[i];
118+
res[left.length+i+1] = (i==0 ? "└─ " : " ") +right[i];
119119
}
120120
return res;
121121
}

src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,6 @@ public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand ou
7575
Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.WARN);
7676
}
7777

78-
@SuppressWarnings("unused")
79-
private EinsumContext einc = null;
80-
8178
@Override
8279
public void processInstruction(ExecutionContext ec) {
8380
//get input matrices and scalars, incl pinning of matrices
@@ -97,7 +94,6 @@ public void processInstruction(ExecutionContext ec) {
9794

9895
EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs);
9996

100-
this.einc = einc;
10197
String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : "";
10298

10399
if( LOG.isTraceEnabled() ) LOG.trace("output: "+resultString +" "+einc.outRows+"x"+einc.outCols);
@@ -193,8 +189,9 @@ public void processInstruction(ExecutionContext ec) {
193189
}
194190
if (EXPLAIN != Explain.ExplainType.NONE )
195191
System.out.println("Einsum plan:");
196-
for(var pl : plan){
197-
System.out.println("- "+String.join("\n- ", pl.recursivePrintString()));
192+
for(int i = 0; i < plan.size(); i++) {
193+
System.out.println((i+1)+".");
194+
System.out.println("- "+String.join("\n- ", plan.get(i).recursivePrintString()));
198195
}
199196

200197
remainingMatrices = executePlan(plan, inputs);

src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java

Lines changed: 52 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -47,108 +47,78 @@
4747
public class EinsumTest extends AutomatedTestBase
4848
{
4949
final private static List<Config> TEST_CONFIGS = List.of(
50-
5150
new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // mm
5251
new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))),
5352
new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))),
5453
new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))),
55-
// new Config("ab,cb,zc->az", List.of(shape(500, 900), shape(1000, 900), shape(400, 1000))),
56-
57-
new Config("ji,jk->i", List.of(shape(60, 5), shape(60, 10))),
58-
new Config("ij,jk->i", List.of(shape(5, 60), shape(60, 10))),
59-
60-
new Config("ji,jk->k", List.of(shape(60, 5), shape(60, 10))),
61-
new Config("ij,jk->k", List.of(shape(5, 60), shape(60, 10))),
54+
new Config("ab,bc,cd,de->ae", List.of(shape(5, 6), shape(6, 5),shape(5, 6), shape(6, 5))), // mm chain
6255

63-
new Config("ji,jk->j", List.of(shape(60, 5), shape(60, 10))),
56+
new Config("ji,jk->i", List.of(shape(6, 5), shape(6, 4))),
57+
new Config("ij,jk->i", List.of(shape(5, 6), shape(6, 4))),
58+
new Config("ji,jk->k", List.of(shape(6, 5), shape(6, 4))),
59+
new Config("ij,jk->k", List.of(shape(5, 6), shape(6, 4))),
60+
new Config("ji,jk->j", List.of(shape(6, 5), shape(6, 4))),
6461

6562
new Config("ji,ji->ji", List.of(shape(60, 5), shape(60, 5))), // elemwise mult
66-
new Config("ji,ji,ji->ji", List.of(shape(60, 5),shape(60, 5), shape(60, 5)),
67-
List.of(0.0001, 0.0005, 0.001)),
6863
new Config("ji,ij->ji", List.of(shape(60, 5), shape(5, 60))), // elemwise mult
6964

70-
7165
new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult
7266
new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult
7367
new Config("ij,i->i", List.of(shape(10, 5), shape(10))),
7468
new Config("ij,i->j", List.of(shape(10, 5), shape(10))),
75-
//
76-
new Config("i,i->", List.of(shape(5), shape(5))),
77-
new Config("i,j->", List.of(shape(5), shape(80))),
69+
70+
new Config("i,i->", List.of(shape(5), shape(5))), // dot
71+
new Config("i,j->", List.of(shape(5), shape(80))), // sum
7872
new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult
7973
new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult
8074

8175
new Config("ij->", List.of(shape(10, 5))), // sum
76+
new Config("i->", List.of(shape(10))), // sum
8277
new Config("ij->i", List.of(shape(10, 5))), // sum(1)
8378
new Config("ij->j", List.of(shape(10, 5))), // sum(0)
84-
new Config("ij->ji", List.of(shape(10, 5))), // T
85-
86-
new Config("ab,cd->ba", List.of(shape( 60, 10), shape(6, 5))),
87-
new Config("ab,cd,g->ba", List.of(shape( 60, 10), shape(6, 5), shape(3))),
88-
//
89-
new Config("ab,bc,cd,de->ae", List.of(shape(5, 60), shape(60, 10), shape(10, 5), shape(5, 4))), // chain of mm
90-
//
91-
// new Config("ji,jz,zx->ix", List.of(shape(60, 5), shape( 60, 10), shape(10, 2))),
92-
// new Config("fx,fg,fz,xg->z", List.of(shape(60, 5), shape( 60, 10), shape(60, 6), shape(5, 10))),
93-
new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl)
94-
List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))),
95-
//
96-
new Config("i->", List.of(shape(10))),
97-
new Config("i->i", List.of(shape(10))),
98-
99-
// test fused
100-
new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
101-
new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
102-
new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
103-
new Config("ij,i,j->ij", List.of(shape(10, 5), shape(10),shape(5))),
104-
new Config("ij,i,i->ij", List.of(shape(10, 5), shape(10),shape(10)), List.of(0.01,0.02,0.1)),
105-
new Config("ij,j,j->ij", List.of(shape(10, 5), shape(5),shape(5))),
106-
new Config("ij,i,j->i", List.of(shape(10, 5), shape(10),shape(5))),
107-
new Config("ij,i,i->i", List.of(shape(10, 5), shape(10),shape(10))),
108-
// new Config("ij,j,j->i", List.of(shape(10, 5), shape(5),shape(5))),
109-
new Config("ij,i,j->j", List.of(shape(10, 5), shape(10),shape(5))),
110-
// new Config("ij,i,i->j", List.of(shape(10, 5), shape(10),shape(10))),
111-
// new Config("ij,j,j->j", List.of(shape(10, 5), shape(5),shape(5))),
112-
new Config("ij,i,j->", List.of(shape(10, 5), shape(10),shape(5))),
113-
// new Config("ij,i,i->", List.of(shape(10, 5), shape(10),shape(10))),
114-
// new Config("ij,j,j->", List.of(shape(10, 5), shape(5),shape(5))),
115-
116-
// test fuesed:
117-
new Config("ij,ij,ji,i,j->i", List.of(shape(7, 5), shape(7, 5),shape(5, 7),shape(7),shape(5))),
118-
new Config("ij,i,i,j,j->i", List.of(shape(7, 50), shape(7),shape(7),shape(50),shape(50))),
119-
new Config("ij,i,i,j,j,z->i", List.of(shape(7, 50), shape(7),shape(7),shape(50),shape(50),shape(2)),List.of(1.0,1.0,1.0,1.0,1.0,1.0) ), // include scalar to tmpl
120-
new Config("ij,ij,ij,i,j->j", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5))),
121-
new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))),
122-
// new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))),
123-
124-
125-
new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))),
126-
new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))),
127-
new Config("ij,i,j->j", List.of(shape(100, 5),shape(100),shape(5))),
128-
new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10))),
129-
new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
130-
new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10))),
131-
new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10))),
132-
new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5),shape(5, 10))),
133-
new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
134-
new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
135-
new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
136-
new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
137-
new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
138-
new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
139-
new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
140-
new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
141-
new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
142-
new Config("ij,ij,ji,j,i,ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)),
143-
//skinny right:
144-
new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // with outer mm
145-
// no skinny right:
146-
new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // with outer mm
147-
new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)) // with outer mm
148-
,new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10)))
149-
150-
// ,new Config("ij,ij,ji->ij", List.of(shape(100, 50), shape(100, 50),shape(50, 100)), List.of(0.1,1.0,1.0))
79+
new Config("ij->ji", List.of(shape(10, 5))), // T
80+
new Config("ij->ij", List.of(shape(10, 5))),
81+
new Config("i->i", List.of(shape(10))),
82+
new Config("ii->i", List.of(shape(10, 10))), // Diag
83+
new Config("ii,i->i", List.of(shape(10, 10),shape(10))), // Diag*vec
84+
85+
new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6, 5))), // sum cd to scalar and multiply ab
15186

87+
new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl fallback)
88+
List.of(shape(5, 6), shape(5, 3), shape(5, 10), shape(6, 3), shape(10, 6), shape(10, 3))),
89+
90+
// test fused:
91+
new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5), shape(5, 10))),
92+
new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
93+
new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
94+
new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
95+
96+
new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5), shape(5, 10))),
97+
new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
98+
new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
99+
new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
100+
101+
new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5), shape(5, 10))),
102+
new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
103+
new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
104+
new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
105+
106+
new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5), shape(5, 10))),
107+
new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))),
108+
new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))),
109+
new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))),
110+
111+
new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6))),
112+
new Config("ij,ij,ij,i,j,iz,z->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6),shape(6))),
113+
114+
new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))),
115+
new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))),
116+
117+
new Config("ij,ij,ji,j,i, ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)),
118+
new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // //skinny right with outer mm
119+
new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // // no skinny right
120+
new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)),
121+
new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10)))
152122
);
153123
private final int id;
154124
private final String einsumStr;
@@ -229,7 +199,6 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar)
229199
sb.append("\n");
230200
}
231201
sb.append("\n");
232-
// sb.append("for (i in 1:5) {\n");
233202
sb.append("R = einsum(\"");
234203
sb.append(config.einsumStr);
235204
sb.append("\", ");
@@ -242,7 +211,6 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar)
242211
sb.append("A");
243212
sb.append(config.shapes.size() - 1);
244213
sb.append(")");
245-
// sb.append("\n}\n");
246214

247215
sb.append("\n\n");
248216
sb.append("write(R, $1)\n");

0 commit comments

Comments
 (0)