|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
1 | 20 | package org.apache.sysds.runtime.einsum; |
2 | 21 |
|
3 | 22 | import org.apache.commons.logging.Log; |
|
25 | 44 |
|
26 | 45 | public class EOpNodeBinary extends EOpNode { |
27 | 46 |
|
28 | | - public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed |
| 47 | + |
| 48 | + public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed |
29 | 49 | ////// summations: ////// |
30 | 50 | aB_a,// -> B |
31 | 51 | Ba_a, // -> B |
@@ -58,14 +78,45 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b |
58 | 78 | public EOpNode _left; |
59 | 79 | public EOpNode _right; |
60 | 80 | public EBinaryOperand _operand; |
| 81 | + private boolean transposeResult; |
61 | 82 | public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ |
62 | 83 | super(c1,c2); |
63 | 84 | this._left = left; |
64 | 85 | this._right = right; |
65 | 86 | this._operand = operand; |
66 | 87 | } |
| 88 | + public void setTransposeResult(boolean transposeResult){ |
| 89 | + this.transposeResult = transposeResult; |
| 90 | + } |
67 | 91 |
|
68 | | - @Override |
| 92 | + public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) { |
| 93 | + if (left.c2 == right.c1) { return new EOpNodeBinary(left.c1, right.c2, left, right, EBinaryOperand.Ba_aC); } |
| 94 | + if (left.c2 == right.c2) { return new EOpNodeBinary(left.c1, right.c1, left, right, EBinaryOperand.Ba_Ca); } |
| 95 | + if (left.c1 == right.c1) { return new EOpNodeBinary(left.c2, right.c2, left, right, EBinaryOperand.aB_aC); } |
| 96 | + if (left.c1 == right.c2) { |
| 97 | + var res = new EOpNodeBinary(left.c2, right.c1, left, right, EBinaryOperand.aB_Ca); |
| 98 | + res.setTransposeResult(true); |
| 99 | + return res; |
| 100 | + } |
| 101 | + throw new RuntimeException("EOpNodeBinary::combineMatrixMultiply: invalid matrix operation"); |
| 102 | + } |
| 103 | + |
| 104 | + @Override |
| 105 | + public String[] recursivePrintString() { |
| 106 | + String[] left = _left.recursivePrintString(); |
| 107 | + String[] right = _right.recursivePrintString(); |
| 108 | + String[] res = new String[left.length + right.length+1]; |
| 109 | + res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString(); |
| 110 | + for (int i=0; i<left.length; i++) { |
| 111 | + res[i+1] = (i==0 ? "┌─ " : "| ") +left[i]; |
| 112 | + } |
| 113 | + for (int i=0; i<right.length; i++) { |
| 114 | + res[left.length+i+1] = (i==0 ? "└─ " : "| ") +right[i]; |
| 115 | + } |
| 116 | + return res; |
| 117 | + } |
| 118 | + |
| 119 | + @Override |
69 | 120 | public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) { |
70 | 121 | EOpNodeBinary bin = this; |
71 | 122 | MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG); |
@@ -204,6 +255,10 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, |
204 | 255 | } |
205 | 256 |
|
206 | 257 | } |
| 258 | + if(transposeResult){ |
| 259 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 260 | + res = res.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 261 | + } |
207 | 262 | return res; |
208 | 263 | } |
209 | 264 |
|
|
0 commit comments