|
1 | 1 | package org.apache.sysds.runtime.einsum; |
2 | 2 |
|
| 3 | +import org.apache.commons.logging.Log; |
| 4 | +import org.apache.sysds.runtime.functionobjects.Multiply; |
| 5 | +import org.apache.sysds.runtime.functionobjects.Plus; |
| 6 | +import org.apache.sysds.runtime.functionobjects.ReduceAll; |
| 7 | +import org.apache.sysds.runtime.functionobjects.ReduceCol; |
| 8 | +import org.apache.sysds.runtime.functionobjects.ReduceRow; |
| 9 | +import org.apache.sysds.runtime.functionobjects.SwapIndex; |
| 10 | +import org.apache.sysds.runtime.instructions.cp.DoubleObject; |
| 11 | +import org.apache.sysds.runtime.instructions.cp.ScalarObject; |
| 12 | +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; |
| 13 | +import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| 14 | +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; |
| 15 | +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; |
| 16 | +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; |
| 17 | +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; |
| 18 | +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; |
| 19 | + |
| 20 | +import java.util.ArrayList; |
| 21 | + |
| 22 | +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; |
| 23 | +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; |
| 24 | + |
3 | 25 | public class EOpNodeBinary extends EOpNode { |
| 26 | + |
4 | 27 | public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed |
5 | 28 | ////// summations: ////// |
6 | 29 | aB_a,// -> B |
@@ -31,13 +54,145 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b |
31 | 54 | AB_scalar, // m-scalar |
32 | 55 | scalar_scalar |
33 | 56 | } |
34 | | - public EOpNode left; |
35 | | - public EOpNode right; |
36 | | - public EBinaryOperand operand; |
| 57 | + public EOpNode _left; |
| 58 | + public EOpNode _right; |
| 59 | + public EBinaryOperand _operand; |
37 | 60 | public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ |
38 | 61 | super(c1,c2); |
39 | | - this.left = left; |
40 | | - this.right = right; |
41 | | - this.operand = operand; |
| 62 | + this._left = left; |
| 63 | + this._right = right; |
| 64 | + this._operand = operand; |
42 | 65 | } |
| 66 | + |
| 67 | + @Override |
| 68 | + public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) { |
| 69 | + EOpNodeBinary bin = this; |
| 70 | + MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG); |
| 71 | + MatrixBlock right = _right.computeEOpNode(inputs, numThreads, LOG); |
| 72 | + |
| 73 | + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); |
| 74 | + |
| 75 | + MatrixBlock res; |
| 76 | + |
| 77 | + LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin); |
| 78 | + |
| 79 | + switch (bin._operand){ |
| 80 | + case AB_AB -> { |
| 81 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 82 | + } |
| 83 | + case A_A -> { |
| 84 | + ensureMatrixBlockColumnVector(left); |
| 85 | + ensureMatrixBlockColumnVector(right); |
| 86 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 87 | + } |
| 88 | + case a_a -> { |
| 89 | + ensureMatrixBlockColumnVector(left); |
| 90 | + ensureMatrixBlockColumnVector(right); |
| 91 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 92 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); |
| 93 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 94 | + } |
| 95 | + //////////// |
| 96 | + case Ba_Ba -> { |
| 97 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 98 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); |
| 99 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 100 | + } |
| 101 | + case aB_aB -> { |
| 102 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 103 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); |
| 104 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 105 | + ensureMatrixBlockColumnVector(res); |
| 106 | + } |
| 107 | + case ab_ab -> { |
| 108 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 109 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); |
| 110 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 111 | + } |
| 112 | + case ab_ba -> { |
| 113 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 114 | + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 115 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 116 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); |
| 117 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 118 | + } |
| 119 | + case Ba_aB -> { |
| 120 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 121 | + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 122 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 123 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); |
| 124 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 125 | + } |
| 126 | + case aB_Ba -> { |
| 127 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 128 | + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 129 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 130 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); |
| 131 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 132 | + } |
| 133 | + |
| 134 | + ///////// |
| 135 | + case AB_BA -> { |
| 136 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 137 | + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 138 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); |
| 139 | + } |
| 140 | + case Ba_aC -> { |
| 141 | + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); |
| 142 | + } |
| 143 | + case aB_Ca -> { |
| 144 | + res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), numThreads); |
| 145 | + } |
| 146 | + case Ba_Ca -> { |
| 147 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 148 | + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 149 | + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); |
| 150 | + } |
| 151 | + case aB_aC -> { |
| 152 | + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); |
| 153 | + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); |
| 154 | + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); |
| 155 | + } |
| 156 | + case A_scalar, AB_scalar -> { |
| 157 | + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); |
| 158 | + } |
| 159 | + case BA_A -> { |
| 160 | + ensureMatrixBlockRowVector(right); |
| 161 | + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); |
| 162 | + } |
| 163 | + case Ba_a -> { |
| 164 | + ensureMatrixBlockRowVector(right); |
| 165 | + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); |
| 166 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); |
| 167 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 168 | + } |
| 169 | + |
| 170 | + case AB_A -> { |
| 171 | + ensureMatrixBlockColumnVector(right); |
| 172 | + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); |
| 173 | + } |
| 174 | + case aB_a -> { |
| 175 | + ensureMatrixBlockColumnVector(right); |
| 176 | + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); |
| 177 | + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); |
| 178 | + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); |
| 179 | + ensureMatrixBlockColumnVector(res); |
| 180 | + } |
| 181 | + |
| 182 | + case A_B -> { |
| 183 | + ensureMatrixBlockColumnVector(left); |
| 184 | + ensureMatrixBlockRowVector(right); |
| 185 | + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); |
| 186 | + } |
| 187 | + case scalar_scalar -> { |
| 188 | + return new MatrixBlock(left.get(0,0)*right.get(0,0)); |
| 189 | + } |
| 190 | + default -> { |
| 191 | + throw new IllegalArgumentException("Unexpected value: " + bin._operand.toString()); |
| 192 | + } |
| 193 | + |
| 194 | + } |
| 195 | + return res; |
| 196 | + } |
| 197 | + |
43 | 198 | } |
0 commit comments