Skip to content

Commit eac7e0e

Browse files
HubertKrawczykmboehm7
authored andcommitted
[SYSTEMDS-3909] Final framework for einsum expressions
Closes #2391.
1 parent 95270e9 commit eac7e0e

File tree

15 files changed

+2153
-675
lines changed

15 files changed

+2153
-675
lines changed

src/main/java/org/apache/sysds/hops/NaryOp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ else if ( areDimsBelowThreshold() )
165165
setRequiresRecompileIfNecessary();
166166

167167
//ensure cp exec type for single-node operations
168-
if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST
168+
if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST || _op == OpOpN.EINSUM
169169
//TODO: cbind/rbind of lists only support in CP right now
170170
|| (_op == OpOpN.CBIND && getInput().get(0).getDataType().isList())
171171
|| (_op == OpOpN.RBIND && getInput().get(0).getDataType().isList())

src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ protected void optimizeMMChain(Hop hop, List<Hop> mmChain, List<Hop> mmOperators
210210
* Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
211211
* Introduction to Algorithms, Third Edition, MIT Press, page 395.
212212
*/
213-
private static int[][] mmChainDP(double[] dimArray, int size)
213+
public static int[][] mmChainDP(double[] dimArray, int size)
214214
{
215215
double[][] dpMatrix = new double[size][size]; //min cost table
216216
int[][] split = new int[size][size]; //min cost index table
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.einsum;
21+
22+
import org.apache.commons.logging.Log;
23+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
24+
import scala.Int;
25+
26+
import java.util.ArrayList;
27+
import java.util.Arrays;
28+
import java.util.List;
29+
30+
public abstract class EOpNode {
31+
public Character c1;
32+
public Character c2;
33+
public Integer dim1;
34+
public Integer dim2;
35+
public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
36+
this.c1 = c1;
37+
this.c2 = c2;
38+
this.dim1 = dim1;
39+
this.dim2 = dim2;
40+
}
41+
42+
public String getOutputString() {
43+
if(c1 == null) return "''";
44+
if(c2 == null) return c1.toString();
45+
return c1.toString() + c2.toString();
46+
}
47+
public abstract List<EOpNode> getChildren();
48+
49+
public String[] recursivePrintString(){
50+
ArrayList<String[]> inpStrings = new ArrayList<>();
51+
for (EOpNode node : getChildren()) {
52+
inpStrings.add(node.recursivePrintString());
53+
}
54+
String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new);
55+
String[] res = new String[1 + inpRes.length];
56+
57+
res[0] = this.toString();
58+
59+
for (int i=0; i<inpRes.length; i++) {
60+
res[i+1] = (i==0 ? "┌ " : (i==inpRes.length-1 ? "└ " : "| "))+inpRes[i];
61+
}
62+
return res;
63+
};
64+
65+
public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG);
66+
67+
public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2);
68+
}
69+

0 commit comments

Comments
 (0)