4747public class EinsumTest extends AutomatedTestBase
4848{
4949 final private static List <Config > TEST_CONFIGS = List .of (
50-
5150 new Config ("ij,jk->ik" , List .of (shape (5 , 6 ), shape (6 , 5 ))), // mm
5251 new Config ("ji,jk->ik" , List .of (shape (6 , 5 ), shape (6 , 10 ))),
5352 new Config ("ji,kj->ik" , List .of (shape (6 , 5 ), shape (10 , 6 ))),
5453 new Config ("ij,kj->ik" , List .of (shape (5 , 6 ), shape (10 , 6 ))),
55- // new Config("ab,cb,zc->az", List.of(shape(500, 900), shape(1000, 900), shape(400, 1000))),
56-
57- new Config ("ji,jk->i" , List .of (shape (60 , 5 ), shape (60 , 10 ))),
58- new Config ("ij,jk->i" , List .of (shape (5 , 60 ), shape (60 , 10 ))),
59-
60- new Config ("ji,jk->k" , List .of (shape (60 , 5 ), shape (60 , 10 ))),
61- new Config ("ij,jk->k" , List .of (shape (5 , 60 ), shape (60 , 10 ))),
54+ new Config ("ab,bc,cd,de->ae" , List .of (shape (5 , 6 ), shape (6 , 5 ),shape (5 , 6 ), shape (6 , 5 ))), // mm chain
6255
63- new Config ("ji,jk->j" , List .of (shape (60 , 5 ), shape (60 , 10 ))),
56+ new Config ("ji,jk->i" , List .of (shape (6 , 5 ), shape (6 , 4 ))),
57+ new Config ("ij,jk->i" , List .of (shape (5 , 6 ), shape (6 , 4 ))),
58+ new Config ("ji,jk->k" , List .of (shape (6 , 5 ), shape (6 , 4 ))),
59+ new Config ("ij,jk->k" , List .of (shape (5 , 6 ), shape (6 , 4 ))),
60+ new Config ("ji,jk->j" , List .of (shape (6 , 5 ), shape (6 , 4 ))),
6461
6562 new Config ("ji,ji->ji" , List .of (shape (60 , 5 ), shape (60 , 5 ))), // elemwise mult
66- new Config ("ji,ji,ji->ji" , List .of (shape (60 , 5 ),shape (60 , 5 ), shape (60 , 5 )),
67- List .of (0.0001 , 0.0005 , 0.001 )),
6863 new Config ("ji,ij->ji" , List .of (shape (60 , 5 ), shape (5 , 60 ))), // elemwise mult
6964
70-
7165 new Config ("ij,i->ij" , List .of (shape (10 , 5 ), shape (10 ))), // col mult
7266 new Config ("ji,i->ij" , List .of (shape (5 , 10 ), shape (10 ))), // row mult
7367 new Config ("ij,i->i" , List .of (shape (10 , 5 ), shape (10 ))),
7468 new Config ("ij,i->j" , List .of (shape (10 , 5 ), shape (10 ))),
75- //
76- new Config ("i,i->" , List .of (shape (5 ), shape (5 ))),
77- new Config ("i,j->" , List .of (shape (5 ), shape (80 ))),
69+
70+ new Config ("i,i->" , List .of (shape (5 ), shape (5 ))), // dot
71+ new Config ("i,j->" , List .of (shape (5 ), shape (80 ))), // sum
7872 new Config ("i,j->ij" , List .of (shape (5 ), shape (80 ))), // outer vect mult
7973 new Config ("i,j->ji" , List .of (shape (5 ), shape (80 ))), // outer vect mult
8074
8175 new Config ("ij->" , List .of (shape (10 , 5 ))), // sum
76+ new Config ("i->" , List .of (shape (10 ))), // sum
8277 new Config ("ij->i" , List .of (shape (10 , 5 ))), // sum(1)
8378 new Config ("ij->j" , List .of (shape (10 , 5 ))), // sum(0)
84- new Config ("ij->ji" , List .of (shape (10 , 5 ))), // T
85-
86- new Config ("ab,cd->ba" , List .of (shape ( 60 , 10 ), shape (6 , 5 ))),
87- new Config ("ab,cd,g->ba" , List .of (shape ( 60 , 10 ), shape (6 , 5 ), shape (3 ))),
88- //
89- new Config ("ab,bc,cd,de->ae" , List .of (shape (5 , 60 ), shape (60 , 10 ), shape (10 , 5 ), shape (5 , 4 ))), // chain of mm
90- //
91- // new Config("ji,jz,zx->ix", List.of(shape(60, 5), shape( 60, 10), shape(10, 2))),
92- // new Config("fx,fg,fz,xg->z", List.of(shape(60, 5), shape( 60, 10), shape(60, 6), shape(5, 10))),
93- new Config ("fx,fg,fz,xg,zx,zg->g" , // each idx 3 times (cell tpl)
94- List .of (shape (5 , 60 ), shape (5 , 30 ), shape (5 , 10 ), shape (60 , 30 ), shape (10 , 60 ), shape (10 , 30 ))),
95- //
96- new Config ("i->" , List .of (shape (10 ))),
97- new Config ("i->i" , List .of (shape (10 ))),
98-
99- // test fused
100- new Config ("ij,ij,ji,i,j->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
101- new Config ("ij,ij,ji,j->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
102- new Config ("ij,ij,ji,i->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
103- new Config ("ij,i,j->ij" , List .of (shape (10 , 5 ), shape (10 ),shape (5 ))),
104- new Config ("ij,i,i->ij" , List .of (shape (10 , 5 ), shape (10 ),shape (10 )), List .of (0.01 ,0.02 ,0.1 )),
105- new Config ("ij,j,j->ij" , List .of (shape (10 , 5 ), shape (5 ),shape (5 ))),
106- new Config ("ij,i,j->i" , List .of (shape (10 , 5 ), shape (10 ),shape (5 ))),
107- new Config ("ij,i,i->i" , List .of (shape (10 , 5 ), shape (10 ),shape (10 ))),
108- // new Config("ij,j,j->i", List.of(shape(10, 5), shape(5),shape(5))),
109- new Config ("ij,i,j->j" , List .of (shape (10 , 5 ), shape (10 ),shape (5 ))),
110- // new Config("ij,i,i->j", List.of(shape(10, 5), shape(10),shape(10))),
111- // new Config("ij,j,j->j", List.of(shape(10, 5), shape(5),shape(5))),
112- new Config ("ij,i,j->" , List .of (shape (10 , 5 ), shape (10 ),shape (5 ))),
113- // new Config("ij,i,i->", List.of(shape(10, 5), shape(10),shape(10))),
114- // new Config("ij,j,j->", List.of(shape(10, 5), shape(5),shape(5))),
115-
116- // test fuesed:
117- new Config ("ij,ij,ji,i,j->i" , List .of (shape (7 , 5 ), shape (7 , 5 ),shape (5 , 7 ),shape (7 ),shape (5 ))),
118- new Config ("ij,i,i,j,j->i" , List .of (shape (7 , 50 ), shape (7 ),shape (7 ),shape (50 ),shape (50 ))),
119- new Config ("ij,i,i,j,j,z->i" , List .of (shape (7 , 50 ), shape (7 ),shape (7 ),shape (50 ),shape (50 ),shape (2 )),List .of (1.0 ,1.0 ,1.0 ,1.0 ,1.0 ,1.0 ) ), // include scalar to tmpl
120- new Config ("ij,ij,ij,i,j->j" , List .of (shape (5 , 5 ), shape (5 , 5 ), shape (5 , 5 ),shape (5 ),shape (5 ))),
121- new Config ("ij,ij,ij,i,j,iz->z" , List .of (shape (5 , 5 ), shape (5 , 5 ), shape (5 , 5 ),shape (5 ),shape (5 ),shape (5 , 60 ))),
122- // new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))),
123-
124-
125- new Config ("ij,i,j,iz->z" , List .of (shape (10 , 5 ),shape (10 ),shape (5 ),shape (10 , 51 ))),
126- new Config ("ij,i,j,iz->z" , List .of (shape (20 , 10 ),shape (20 ),shape (10 ),shape (20 , 10 ))),
127- new Config ("ij,i,j->j" , List .of (shape (100 , 5 ),shape (100 ),shape (5 ))),
128- new Config ("ij,ij,ji->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ))),
129- new Config ("ij,ij,ji,i,j->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
130- new Config ("ij,ij,ji->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ))),
131- new Config ("ij,ij,ji->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ))),
132- new Config ("ij,ij,ji->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ))),
133- new Config ("ij,ij,ji,j->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
134- new Config ("ij,ij,ji,i->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
135- new Config ("ij,ij,ji,i,j->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
136- new Config ("ij,ij,ji,j->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
137- new Config ("ij,ij,ji,i->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
138- new Config ("ij,ij,ji,i,j->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
139- new Config ("ij,ij,ji,j->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
140- new Config ("ij,ij,ji,i->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
141- new Config ("ij,ij,ji,i,j->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
142- new Config ("ij,ij,ji,j,i,ab,ba,ab,a,b->jb" , Map .of ('i' ,10 , 'j' ,5 , 'a' , 11 , 'b' , 6 )),
143- //skinny right:
144- new Config ("ij,ij,ji,j,i,iz->jz" , Map .of ('i' ,600 , 'j' ,10 ,'z' , 6 )), // with outer mm
145- // no skinny right:
146- new Config ("ij,ij,ji,j,i,iz->jz" , Map .of ('i' ,10 , 'j' ,10 ,'z' , 10 )), // with outer mm
147- new Config ("ij,ij,ji,j,i,iz->zj" , Map .of ('i' ,60 , 'j' ,10 ,'z' , 6 )) // with outer mm
148- ,new Config ("ij,ij,ij,jk->ik" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (10 , 5 ),shape (5 , 10 )))
149-
150- // ,new Config("ij,ij,ji->ij", List.of(shape(100, 50), shape(100, 50),shape(50, 100)), List.of(0.1,1.0,1.0))
79+ new Config ("ij->ji" , List .of (shape (10 , 5 ))), // T
80+ new Config ("ij->ij" , List .of (shape (10 , 5 ))),
81+ new Config ("i->i" , List .of (shape (10 ))),
82+ new Config ("ii->i" , List .of (shape (10 , 10 ))), // Diag
83+ new Config ("ii,i->i" , List .of (shape (10 , 10 ),shape (10 ))), // Diag*vec
84+
85+ new Config ("ab,cd->ba" , List .of (shape ( 6 , 10 ), shape (6 , 5 ))), // sum cd to scalar and multiply ab
15186
87+ new Config ("fx,fg,fz,xg,zx,zg->g" , // each idx 3 times (cell tpl fallback)
88+ List .of (shape (5 , 6 ), shape (5 , 3 ), shape (5 , 10 ), shape (6 , 3 ), shape (10 , 6 ), shape (10 , 3 ))),
89+
90+ // test fused:
91+ new Config ("ij,ij,ji->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ), shape (5 , 10 ))),
92+ new Config ("ij,ij,ji,i,j->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
93+ new Config ("ij,ij,ji,i->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
94+ new Config ("ij,ij,ji,j->ij" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
95+
96+ new Config ("ij,ij,ji->i" , List .of (shape (10 , 5 ), shape (10 , 5 ), shape (5 , 10 ))),
97+ new Config ("ij,ij,ji,i,j->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
98+ new Config ("ij,ij,ji,i->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
99+ new Config ("ij,ij,ji,j->i" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
100+
101+ new Config ("ij,ij,ji->j" , List .of (shape (10 , 5 ), shape (10 , 5 ), shape (5 , 10 ))),
102+ new Config ("ij,ij,ji,i,j->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
103+ new Config ("ij,ij,ji,i->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
104+ new Config ("ij,ij,ji,j->j" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
105+
106+ new Config ("ij,ij,ji->" , List .of (shape (10 , 5 ), shape (10 , 5 ), shape (5 , 10 ))),
107+ new Config ("ij,ij,ji,i,j->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ),shape (5 ))),
108+ new Config ("ij,ij,ji,i->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (10 ))),
109+ new Config ("ij,ij,ji,j->" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (5 , 10 ),shape (5 ))),
110+
111+ new Config ("ij,ij,ij,i,j,iz->z" , List .of (shape (5 , 5 ), shape (5 , 5 ), shape (5 , 5 ),shape (5 ),shape (5 ),shape (5 , 6 ))),
112+ new Config ("ij,ij,ij,i,j,iz,z->z" , List .of (shape (5 , 5 ), shape (5 , 5 ), shape (5 , 5 ),shape (5 ),shape (5 ),shape (5 , 6 ),shape (6 ))),
113+
114+ new Config ("ij,i,j,iz->z" , List .of (shape (10 , 5 ),shape (10 ),shape (5 ),shape (10 , 51 ))),
115+ new Config ("ij,i,j,iz->z" , List .of (shape (20 , 10 ),shape (20 ),shape (10 ),shape (20 , 10 ))),
116+
117+ new Config ("ij,ij,ji,j,i, ab,ba,ab,a,b->jb" , Map .of ('i' ,10 , 'j' ,5 , 'a' , 11 , 'b' , 6 )),
118+ new Config ("ij,ij,ji,j,i,iz->jz" , Map .of ('i' ,600 , 'j' ,10 ,'z' , 6 )), // //skinny right with outer mm
119+ new Config ("ij,ij,ji,j,i,iz->jz" , Map .of ('i' ,10 , 'j' ,10 ,'z' , 10 )), // // no skinny right
120+ new Config ("ij,ij,ji,j,i,iz->zj" , Map .of ('i' ,60 , 'j' ,10 ,'z' , 6 )),
121+ new Config ("ij,ij,ij,jk->ik" , List .of (shape (10 , 5 ), shape (10 , 5 ),shape (10 , 5 ),shape (5 , 10 )))
152122 );
153123 private final int id ;
154124 private final String einsumStr ;
@@ -229,7 +199,6 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar)
229199 sb .append ("\n " );
230200 }
231201 sb .append ("\n " );
232- // sb.append("for (i in 1:5) {\n");
233202 sb .append ("R = einsum(\" " );
234203 sb .append (config .einsumStr );
235204 sb .append ("\" , " );
@@ -242,7 +211,6 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar)
242211 sb .append ("A" );
243212 sb .append (config .shapes .size () - 1 );
244213 sb .append (")" );
245- // sb.append("\n}\n");
246214
247215 sb .append ("\n \n " );
248216 sb .append ("write(R, $1)\n " );
0 commit comments