Skip to content

Commit 3741895

Browse files
HubertKrawczykmboehm7
authored andcommitted
[SYSTEMDS-3909] New einsum expression evaluation framework
Closes #2312. Closes #2265.
1 parent b1c5d64 commit 3741895

File tree

19 files changed

+1640
-8
lines changed

19 files changed

+1640
-8
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ public enum Builtins {
403403
UNDER_SAMPLING("underSampling", true),
404404
UNIQUE("unique", false, true),
405405
UPPER_TRI("upper.tri", false, true),
406+
EINSUM("einsum", false, false),
406407
XDUMMY1("xdummy1", true), //error handling test
407408
XDUMMY2("xdummy2", true); //error handling test
408409

src/main/java/org/apache/sysds/common/InstructionType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ public enum InstructionType {
6262
PMMJ,
6363
MMChain,
6464
Union,
65+
EINSUM,
6566

6667
//SP Types
6768
MAPMM,

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ public enum Opcodes {
174174
RBIND("rbind", InstructionType.BuiltinNary),
175175
EVAL("eval", InstructionType.BuiltinNary),
176176
LIST("list", InstructionType.BuiltinNary),
177+
EINSUM("einsum", InstructionType.BuiltinNary),
177178

178179
//Parametrized builtin functions
179180
AUTODIFF("autoDiff", InstructionType.ParameterizedBuiltin),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,8 +767,8 @@ public String toString() {
767767

768768
/** Operations that require a variable number of operands*/
769769
public enum OpOpN {
770-
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
771-
770+
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST, EINSUM;
771+
772772
public boolean isCellOp() {
773773
return this == MIN || this == MAX || this == PLUS || this == MULT;
774774
}

src/main/java/org/apache/sysds/hops/NaryOp.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.sysds.lops.Lop;
2727
import org.apache.sysds.common.Types.ExecType;
2828
import org.apache.sysds.lops.Nary;
29+
import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
2930
import org.apache.sysds.runtime.meta.DataCharacteristics;
3031
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
3132

@@ -235,6 +236,14 @@ public void refreshSizeInformation() {
235236
setDim1(getInput().size());
236237
setDim2(1);
237238
break;
239+
case EINSUM:
240+
String equationString = ((LiteralOp) _input.get(0)).getStringValue();
241+
var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, this.getInput().subList(1, this.getInput().size()));
242+
243+
setDim1(dims.getLeft());
244+
setDim2(dims.getMiddle());
245+
setDataType(dims.getRight());
246+
break;
238247
case PRINTF:
239248
case EVAL:
240249
//do nothing:

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
public class CNodeCell extends CNodeTpl
3333
{
34-
protected static final String JAVA_TEMPLATE =
34+
public static final String JAVA_TEMPLATE =
3535
"package codegen;\n"
3636
+ "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
3737
+ "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n"

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ public CNodeData(CNodeData node, String newName) {
5656
_cols = node.getNumCols();
5757
_dataType = node.getDataType();
5858
}
59+
60+
public CNodeData(String name, long hopID, long rows, long cols, DataType dataType) {
61+
_name = name;
62+
_hopID = hopID;
63+
_rows = rows;
64+
_cols = cols;
65+
_dataType = dataType;
66+
}
5967

6068
@Override
6169
public String getVarname() {

src/main/java/org/apache/sysds/lops/Nary.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ private String getOpcode() {
111111
case RBIND:
112112
case EVAL:
113113
case LIST:
114+
case EINSUM:
114115
return operationType.name().toLowerCase();
115116
case MIN:
116117
case MAX:

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Arrays;
2424
import java.util.HashMap;
2525
import java.util.HashSet;
26+
import java.util.LinkedList;
2627

2728
import org.antlr.v4.runtime.ParserRuleContext;
2829
import org.apache.commons.lang3.ArrayUtils;
@@ -35,6 +36,7 @@
3536
import org.apache.sysds.conf.ConfigurationManager;
3637
import org.apache.sysds.hops.OptimizerUtils;
3738
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
39+
import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
3840
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
3941
import org.apache.sysds.runtime.util.DnnUtils;
4042
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -751,7 +753,9 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
751753
else
752754
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
753755
break;
754-
756+
case EINSUM:
757+
validateEinsum((DataIdentifier) getOutputs()[0]);
758+
break;
755759
default: //always unconditional
756760
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
757761
}
@@ -2063,7 +2067,9 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
20632067
output.setValueType(ValueType.INT64);
20642068
output.setNnz(id.getDim2());
20652069
break;
2066-
2070+
case EINSUM:
2071+
validateEinsum(output);
2072+
break;
20672073
default:
20682074
if( isMathFunction() ) {
20692075
checkMathFunctionParam();
@@ -2096,6 +2102,49 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
20962102
}
20972103
}
20982104

2105+
private void validateEinsum(DataIdentifier output){
2106+
if(getSecondExpr() == null)
2107+
raiseValidateError("Einsum: at least one input matrix required", false,
2108+
LanguageErrorCodes.INVALID_PARAMETERS);
2109+
2110+
if(!(getFirstExpr() instanceof StringIdentifier))
2111+
raiseValidateError("Einsum: first argument has to be equation str", false,
2112+
LanguageErrorCodes.INVALID_PARAMETERS);
2113+
2114+
String equationString = ((StringIdentifier)getFirstExpr()).getValue();
2115+
2116+
if (equationString.length() == 0) raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS);
2117+
if (equationString.charAt(0) == '-' || equationString.charAt(0) == ',') raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS);
2118+
2119+
Expression[] expressions = getAllExpr();
2120+
boolean allDimsKnown = true;
2121+
2122+
LinkedList<Identifier> matrixBlocks = new LinkedList<>();
2123+
for (int i=1;i<expressions.length; i++){
2124+
checkMatrixParam(expressions[i]);
2125+
if(!(expressions[i]).getOutput().dimsKnown()){
2126+
allDimsKnown = false;
2127+
break;
2128+
}
2129+
matrixBlocks.add((expressions[i].getOutput()));
2130+
}
2131+
2132+
if(allDimsKnown){
2133+
var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, matrixBlocks);
2134+
2135+
output.setDataType(dims.getRight());
2136+
output.setDimensions(dims.getLeft(), dims.getMiddle());
2137+
}else{
2138+
DataType dataType = EinsumEquationValidator.validateEinsumEquationNoDimensions(equationString, _args.length - 1);
2139+
2140+
output.setDataType(dataType);
2141+
output.setDimensions(-1l, -1l);
2142+
}
2143+
2144+
output.setValueType(ValueType.FP64);
2145+
output.setBlocksize(getSecondExpr().getOutput().getBlocksize());
2146+
}
2147+
20992148
private void setBinaryOutputProperties(DataIdentifier output) {
21002149
DataType dt1 = getFirstExpr().getOutput().getDataType();
21012150
DataType dt2 = getSecondExpr().getOutput().getDataType();

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2447,7 +2447,10 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D
24472447
new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
24482448
OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops));
24492449
break;
2450-
2450+
case EINSUM:
2451+
currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
2452+
OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops));
2453+
break;
24512454
case PPRED:
24522455
String sop = ((StringIdentifier)source.getThirdExpr()).getValue();
24532456
sop = sop.replace("\"", "");

0 commit comments

Comments
 (0)