Skip to content

Commit 911c796

Browse files
bugfixes and move code to other places
1 parent d64e371 commit 911c796

File tree

5 files changed

+184
-175
lines changed

5 files changed

+184
-175
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ protected void optimizeMMChain(Hop hop, List<Hop> mmChain, List<Hop> mmOperators
210210
* Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
211211
* Introduction to Algorithms, Third Edition, MIT Press, page 395.
212212
*/
213-
private static int[][] mmChainDP(double[] dimArray, int size)
213+
public static int[][] mmChainDP(double[] dimArray, int size)
214214
{
215215
double[][] dpMatrix = new double[size][size]; //min cost table
216216
int[][] split = new int[size][size]; //min cost index table

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
package org.apache.sysds.runtime.einsum;
2121

22+
import org.apache.commons.lang3.tuple.Pair;
23+
import org.apache.commons.lang3.tuple.Triple;
2224
import org.apache.commons.logging.Log;
2325
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
2426
import org.apache.sysds.runtime.functionobjects.Multiply;
@@ -38,6 +40,8 @@
3840
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
3941

4042
import java.util.ArrayList;
43+
import java.util.HashMap;
44+
import java.util.function.Predicate;
4145

4246
import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
4347
import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector;
@@ -281,4 +285,112 @@ public void reorderChildren(Character outChar1, Character outChar2) {
281285
}
282286
}
283287

288+
// used in old method
289+
public static Triple<Integer, EBinaryOperand, Pair<Character, Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character outChar1, Character outChar2){
290+
Predicate<Character> cannotBeSummed = (c) ->
291+
c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2;
292+
293+
if(n1.c1 == null) {
294+
// n2.c1 also has to be null
295+
return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null));
296+
}
297+
298+
if(n2.c1 == null) {
299+
if(n1.c2 == null)
300+
return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null));
301+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2));
302+
}
303+
304+
if(n1.c1 == n2.c1){
305+
if(n1.c2 != null){
306+
if ( n1.c2 == n2.c2){
307+
if( cannotBeSummed.test(n1.c1)){
308+
if(cannotBeSummed.test(n1.c2)){
309+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2));
310+
}
311+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null));
312+
}
313+
314+
if(cannotBeSummed.test(n1.c2)){
315+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null));
316+
}
317+
318+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null));
319+
320+
}
321+
322+
else if(n2.c2 == null){
323+
if(cannotBeSummed.test(n1.c1)){
324+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2));
325+
}
326+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2)
327+
}
328+
else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
329+
return null;// AB,AC
330+
}
331+
else {
332+
return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2
333+
}
334+
}else{ // n1.c2 = null -> c2.c2 = null
335+
if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){
336+
return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null));
337+
}
338+
return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null));
339+
}
340+
341+
342+
}else{ // n1.c1 != n2.c1
343+
if(n1.c2 == null) {
344+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1));
345+
}
346+
else if(n2.c2 == null) { // ab,c
347+
if (n1.c2 == n2.c1) {
348+
if(cannotBeSummed.test(n1.c2)){
349+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2));
350+
}
351+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
352+
}
353+
return null; // AB,C
354+
}
355+
else if (n1.c2 == n2.c1) {
356+
if(n1.c1 == n2.c2){ // ab,ba
357+
if(cannotBeSummed.test(n1.c1)){
358+
if(cannotBeSummed.test(n1.c2)){
359+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2));
360+
}
361+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null));
362+
}
363+
if(cannotBeSummed.test(n1.c2)){
364+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null));
365+
}
366+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null));
367+
}
368+
if(cannotBeSummed.test(n1.c2)){
369+
return null; // AB_B
370+
}else{
371+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2));
372+
// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
373+
// return null; // AB_B
374+
// }
375+
// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
376+
}
377+
}
378+
if(n1.c1 == n2.c2) {
379+
if(cannotBeSummed.test(n1.c1)){
380+
return null; // AB_B
381+
}
382+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult
383+
}
384+
else if (n1.c2 == n2.c2) {
385+
if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){
386+
return null; // BA_CA
387+
}else{
388+
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1
389+
}
390+
}
391+
else { // something like ab,cd
392+
return null;
393+
}
394+
}
395+
}
284396
}

src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ private void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, dou
108108
int bi = 0;
109109
double[] TMP1 = null;
110110
if (_ABCount != 0){
111+
if(_ABCount == 1 & _ACount == 0 && _BCount == 0){
112+
LibMatrixMult.vectMultiplyWrite(a, b[0].values(rix), c, ai, ai, ci, len);
113+
return;
114+
}
111115
TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len);
112116
while (bi < _ABCount) {
113117
if(_ACount == 0 && _BCount == 0 && bi == _ABCount-1) {

src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java

Lines changed: 13 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,14 @@
4343
import org.apache.sysds.utils.Explain;
4444

4545
import java.util.*;
46-
import java.util.function.Predicate;
4746
import java.util.stream.Collectors;
4847

4948
import static org.apache.sysds.api.DMLScript.EXPLAIN;
5049
import static org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP;
5150

5251
public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
5352
public static final boolean FORCE_CELL_TPL = false;
54-
public static final boolean FUSED = true;
53+
// public static final boolean FUSED = true;
5554
public static final boolean FUSE_OUTER_MULTIPLY = true;
5655

5756

@@ -132,10 +131,10 @@ public void processInstruction(ExecutionContext ec) {
132131
ArrayList<MatrixBlock> remainingMatrices;
133132

134133
if(!FORCE_CELL_TPL) {
135-
if(true){
134+
if(true){ // new way: search for fusions and matrix-multiplications chain in a loop
136135
plan = generatePlanFusionAndMM(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2);
137136
}else { // old way: try to do fusion first and then rest in binary fashion cost based
138-
if(FUSED) {
137+
if(true /*FUSED*/) {
139138
ret = new ArrayList<>();
140139
EOpNodeFuse fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2,
141140
einc.charToDimensionSize, characterToOccurences, ret);
@@ -152,10 +151,9 @@ public void processInstruction(ExecutionContext ec) {
152151
eOpNodes = ret;
153152
}
154153
}
155-
156154
}
157155

158-
Pair<Integer, List<EOpNode>> costAndPlan = generatePlanBinaryCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences,
156+
Pair<Integer, List<EOpNode>> costAndPlan = generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences,
159157
einc.outChar1, einc.outChar2);
160158
plan = costAndPlan.getRight();
161159
}
@@ -191,6 +189,7 @@ public void processInstruction(ExecutionContext ec) {
191189
plan.set(0, new EOpNodeBinary(plan.get(0).c1, plan.get(1).c1, plan.get(0), plan.get(1), EBinaryOperand.A_B));
192190
if (plan.get(0).c1 == einc.outChar2 && plan.get(1).c1 == einc.outChar1)
193191
plan.set(0, new EOpNodeBinary(plan.get(1).c1, plan.get(0).c1, plan.get(1), plan.get(0), EBinaryOperand.A_B));
192+
plan.remove(1);
194193
}
195194
if (EXPLAIN != Explain.ExplainType.NONE )
196195
System.out.println("Einsum plan:");
@@ -224,6 +223,7 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){
224223
}
225224
}else if (einc.outChar1 != null){
226225
if(resNode.c1 == einc.outChar1 && resNode.c2 == null){
226+
ensureMatrixBlockColumnVector(remainingMatrices.get(0));
227227
ec.setMatrixOutput(output.getName(), remainingMatrices.get(0));
228228
}else{
229229
if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
@@ -255,6 +255,9 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){
255255

256256
MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSize, summingChars, einc.outRows, einc.outCols);
257257

258+
if (einc.outChar2 == null)
259+
ensureMatrixBlockColumnVector(res);
260+
258261
if (einc.outRows == 1 && einc.outCols == 1)
259262
ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0)));
260263
else ec.setMatrixOutput(output.getName(), res);
@@ -422,7 +425,6 @@ private static List<EOpNode> generatePlanFusionAndMM(ArrayList<EOpNode> eOpNodes
422425
lastNumOfOperands = eOpNodes.size();
423426

424427
EOpNodeFuse fuse = null;
425-
426428
do {
427429
ret = new ArrayList<>();
428430
fuse = EOpNodeFuse.match(eOpNodes, outChar1, outChar2, charToSizeMap, charToOccurences, ret);
@@ -445,7 +447,6 @@ private static List<EOpNode> generatePlanFusionAndMM(ArrayList<EOpNode> eOpNodes
445447
ret.add(bin);
446448
}
447449
eOpNodes = ret;
448-
449450
}
450451

451452
return eOpNodes;
@@ -596,15 +597,15 @@ private static ArrayList<List<EOpNode>> findMatrixMultiplicationChains(ArrayList
596597
}
597598

598599
// old way
599-
private Pair<Integer, List<EOpNode>> generatePlanBinaryCostBased(int cost, ArrayList<EOpNode> operands, HashMap<Character, Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character outChar1, Character outChar2) {
600+
private Pair<Integer, List<EOpNode>> generateBinaryPlanCostBased(int cost, ArrayList<EOpNode> operands, HashMap<Character, Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character outChar1, Character outChar2) {
600601
Integer minCost = cost;
601602
List<EOpNode> minNodes = operands;
602603

603604
if (operands.size() == 2){
604605
boolean swap = (operands.get(0).c2 == null && operands.get(1).c2 != null) || operands.get(0).c1 == null;
605606
EOpNode n1 = operands.get(!swap ? 0 : 1);
606607
EOpNode n2 = operands.get(!swap ? 1 : 0);
607-
Triple<Integer, EBinaryOperand, Pair<Character, Character>> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2);
608+
Triple<Integer, EBinaryOperand, Pair<Character, Character>> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2);
608609
if (t != null) {
609610
EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle());
610611
int thisCost = cost + t.getLeft();
@@ -625,7 +626,7 @@ else if (operands.size() == 1){
625626
EOpNode n2 = operands.get(!swap ? j : i);
626627

627628

628-
Triple<Integer, EBinaryOperand, Pair<Character, Character>> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2);
629+
Triple<Integer, EBinaryOperand, Pair<Character, Character>> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2);
629630
if (t != null){
630631
EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle());
631632
int thisCost = cost + t.getLeft();
@@ -644,7 +645,7 @@ else if (operands.size() == 1){
644645
}
645646
newOperands.add(newNode);
646647

647-
Pair<Integer, List<EOpNode>> furtherPlan = generatePlanBinaryCostBased(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2);
648+
Pair<Integer, List<EOpNode>> furtherPlan = generateBinaryPlanCostBased(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2);
648649
if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){
649650
minCost = furtherPlan.getLeft();
650651
minNodes = furtherPlan.getRight();
@@ -663,114 +664,6 @@ else if (operands.size() == 1){
663664
return Pair.of(minCost, minNodes);
664665
}
665666

666-
private static Triple<Integer, EBinaryOperand, Pair<Character, Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character outChar1, Character outChar2){
667-
Predicate<Character> cannotBeSummed = (c) ->
668-
c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2;
669-
670-
if(n1.c1 == null) {
671-
// n2.c1 also has to be null
672-
return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null));
673-
}
674-
675-
if(n2.c1 == null) {
676-
if(n1.c2 == null)
677-
return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null));
678-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2));
679-
}
680-
681-
if(n1.c1 == n2.c1){
682-
if(n1.c2 != null){
683-
if ( n1.c2 == n2.c2){
684-
if( cannotBeSummed.test(n1.c1)){
685-
if(cannotBeSummed.test(n1.c2)){
686-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2));
687-
}
688-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null));
689-
}
690-
691-
if(cannotBeSummed.test(n1.c2)){
692-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null));
693-
}
694-
695-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null));
696-
697-
}
698-
699-
else if(n2.c2 == null){
700-
if(cannotBeSummed.test(n1.c1)){
701-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2));
702-
}
703-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2)
704-
}
705-
else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
706-
return null;// AB,AC
707-
}
708-
else {
709-
return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2
710-
}
711-
}else{ // n1.c2 = null -> c2.c2 = null
712-
if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){
713-
return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null));
714-
}
715-
return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null));
716-
}
717-
718-
719-
}else{ // n1.c1 != n2.c1
720-
if(n1.c2 == null) {
721-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1));
722-
}
723-
else if(n2.c2 == null) { // ab,c
724-
if (n1.c2 == n2.c1) {
725-
if(cannotBeSummed.test(n1.c2)){
726-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2));
727-
}
728-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
729-
}
730-
return null; // AB,C
731-
}
732-
else if (n1.c2 == n2.c1) {
733-
if(n1.c1 == n2.c2){ // ab,ba
734-
if(cannotBeSummed.test(n1.c1)){
735-
if(cannotBeSummed.test(n1.c2)){
736-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2));
737-
}
738-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null));
739-
}
740-
if(cannotBeSummed.test(n1.c2)){
741-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null));
742-
}
743-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null));
744-
}
745-
if(cannotBeSummed.test(n1.c2)){
746-
return null; // AB_B
747-
}else{
748-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2));
749-
// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
750-
// return null; // AB_B
751-
// }
752-
// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
753-
}
754-
}
755-
if(n1.c1 == n2.c2) {
756-
if(cannotBeSummed.test(n1.c1)){
757-
return null; // AB_B
758-
}
759-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult
760-
}
761-
else if (n1.c2 == n2.c2) {
762-
if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){
763-
return null; // BA_CA
764-
}else{
765-
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1
766-
}
767-
}
768-
else { // we have something like ab,cd
769-
return null;
770-
}
771-
}
772-
}
773-
774667
private ArrayList<MatrixBlock> executePlan(List<EOpNode> plan, ArrayList<MatrixBlock> inputs) {
775668
ArrayList<MatrixBlock> res = new ArrayList<>(plan.size());
776669
for(EOpNode p : plan){

0 commit comments

Comments
 (0)