|
23 | 23 | import java.util.Arrays; |
24 | 24 | import java.util.HashMap; |
25 | 25 | import java.util.HashSet; |
| 26 | +import java.util.LinkedList; |
26 | 27 |
|
27 | 28 | import org.antlr.v4.runtime.ParserRuleContext; |
28 | 29 | import org.apache.commons.lang3.ArrayUtils; |
|
35 | 36 | import org.apache.sysds.conf.ConfigurationManager; |
36 | 37 | import org.apache.sysds.hops.OptimizerUtils; |
37 | 38 | import org.apache.sysds.parser.LanguageException.LanguageErrorCodes; |
| 39 | +import org.apache.sysds.runtime.einsum.EinsumEquationValidator; |
38 | 40 | import org.apache.sysds.runtime.meta.MatrixCharacteristics; |
39 | 41 | import org.apache.sysds.runtime.util.DnnUtils; |
40 | 42 | import org.apache.sysds.runtime.util.UtilFunctions; |
@@ -751,7 +753,9 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) |
751 | 753 | else |
752 | 754 | raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); |
753 | 755 | break; |
754 | | - |
| 756 | + case EINSUM: |
| 757 | + validateEinsum((DataIdentifier) getOutputs()[0]); |
| 758 | + break; |
755 | 759 | default: //always unconditional |
756 | 760 | raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); |
757 | 761 | } |
@@ -2063,7 +2067,9 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV |
2063 | 2067 | output.setValueType(ValueType.INT64); |
2064 | 2068 | output.setNnz(id.getDim2()); |
2065 | 2069 | break; |
2066 | | - |
| 2070 | + case EINSUM: |
| 2071 | + validateEinsum(output); |
| 2072 | + break; |
2067 | 2073 | default: |
2068 | 2074 | if( isMathFunction() ) { |
2069 | 2075 | checkMathFunctionParam(); |
@@ -2096,6 +2102,49 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV |
2096 | 2102 | } |
2097 | 2103 | } |
2098 | 2104 |
|
| 2105 | + private void validateEinsum(DataIdentifier output){ |
| 2106 | + if(getSecondExpr() == null) |
| 2107 | + raiseValidateError("Einsum: at least one input matrix required", false, |
| 2108 | + LanguageErrorCodes.INVALID_PARAMETERS); |
| 2109 | + |
| 2110 | + if(!(getFirstExpr() instanceof StringIdentifier)) |
| 2111 | + raiseValidateError("Einsum: first argument has to be equation str", false, |
| 2112 | + LanguageErrorCodes.INVALID_PARAMETERS); |
| 2113 | + |
| 2114 | + String equationString = ((StringIdentifier)getFirstExpr()).getValue(); |
| 2115 | + |
| 2116 | + if (equationString.length() == 0) raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS); |
| 2117 | + if (equationString.charAt(0) == '-' || equationString.charAt(0) == ',') raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS); |
| 2118 | + |
| 2119 | + Expression[] expressions = getAllExpr(); |
| 2120 | + boolean allDimsKnown = true; |
| 2121 | + |
| 2122 | + LinkedList<Identifier> matrixBlocks = new LinkedList<>(); |
| 2123 | + for (int i=1;i<expressions.length; i++){ |
| 2124 | + checkMatrixParam(expressions[i]); |
| 2125 | + if(!(expressions[i]).getOutput().dimsKnown()){ |
| 2126 | + allDimsKnown = false; |
| 2127 | + break; |
| 2128 | + } |
| 2129 | + matrixBlocks.add((expressions[i].getOutput())); |
| 2130 | + } |
| 2131 | + |
| 2132 | + if(allDimsKnown){ |
| 2133 | + var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, matrixBlocks); |
| 2134 | + |
| 2135 | + output.setDataType(dims.getRight()); |
| 2136 | + output.setDimensions(dims.getLeft(), dims.getMiddle()); |
| 2137 | + }else{ |
| 2138 | + DataType dataType = EinsumEquationValidator.validateEinsumEquationNoDimensions(equationString, _args.length - 1); |
| 2139 | + |
| 2140 | + output.setDataType(dataType); |
| 2141 | + output.setDimensions(-1l, -1l); |
| 2142 | + } |
| 2143 | + |
| 2144 | + output.setValueType(ValueType.FP64); |
| 2145 | + output.setBlocksize(getSecondExpr().getOutput().getBlocksize()); |
| 2146 | + } |
| 2147 | + |
2099 | 2148 | private void setBinaryOutputProperties(DataIdentifier output) { |
2100 | 2149 | DataType dt1 = getFirstExpr().getOutput().getDataType(); |
2101 | 2150 | DataType dt2 = getSecondExpr().getOutput().getDataType(); |
|
0 commit comments