Skip to content

Commit d64e371

Browse files
new approach to fuse
1 parent 41f12e3 commit d64e371

File tree

9 files changed

+1105
-468
lines changed

9 files changed

+1105
-468
lines changed

src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package org.apache.sysds.runtime.einsum;
221

322
import org.apache.commons.logging.Log;
@@ -15,12 +34,13 @@ public EOpNode(Character c1, Character c2){
1534

1635
@Override
1736
public String toString() {
18-
if(c1 == null) return "-";
19-
37+
if(c1 == null) return "''";
2038
if(c2 == null) return c1.toString();
2139
return c1.toString() + c2.toString();
2240
}
2341

42+
public abstract String[] recursivePrintString();
43+
2444
public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG);
2545

2646
public abstract void reorderChildren(Character outChar1, Character outChar2);

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package org.apache.sysds.runtime.einsum;
221

322
import org.apache.commons.logging.Log;
@@ -25,7 +44,8 @@
2544

2645
public class EOpNodeBinary extends EOpNode {
2746

28-
public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
47+
48+
public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
2949
////// summations: //////
3050
aB_a,// -> B
3151
Ba_a, // -> B
@@ -58,14 +78,45 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b
5878
public EOpNode _left;
5979
public EOpNode _right;
6080
public EBinaryOperand _operand;
81+
private boolean transposeResult;
6182
public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){
6283
super(c1,c2);
6384
this._left = left;
6485
this._right = right;
6586
this._operand = operand;
6687
}
88+
public void setTransposeResult(boolean transposeResult){
89+
this.transposeResult = transposeResult;
90+
}
6791

68-
@Override
92+
public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
93+
if (left.c2 == right.c1) { return new EOpNodeBinary(left.c1, right.c2, left, right, EBinaryOperand.Ba_aC); }
94+
if (left.c2 == right.c2) { return new EOpNodeBinary(left.c1, right.c1, left, right, EBinaryOperand.Ba_Ca); }
95+
if (left.c1 == right.c1) { return new EOpNodeBinary(left.c2, right.c2, left, right, EBinaryOperand.aB_aC); }
96+
if (left.c1 == right.c2) {
97+
var res = new EOpNodeBinary(left.c2, right.c1, left, right, EBinaryOperand.aB_Ca);
98+
res.setTransposeResult(true);
99+
return res;
100+
}
101+
throw new RuntimeException("EOpNodeBinary::combineMatrixMultiply: invalid matrix operation");
102+
}
103+
104+
@Override
105+
public String[] recursivePrintString() {
106+
String[] left = _left.recursivePrintString();
107+
String[] right = _right.recursivePrintString();
108+
String[] res = new String[left.length + right.length+1];
109+
res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString();
110+
for (int i=0; i<left.length; i++) {
111+
res[i+1] = (i==0 ? "┌─ " : "| ") +left[i];
112+
}
113+
for (int i=0; i<right.length; i++) {
114+
res[left.length+i+1] = (i==0 ? "└─ " : "| ") +right[i];
115+
}
116+
return res;
117+
}
118+
119+
@Override
69120
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) {
70121
EOpNodeBinary bin = this;
71122
MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG);
@@ -204,6 +255,10 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
204255
}
205256

206257
}
258+
if(transposeResult){
259+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
260+
res = res.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
261+
}
207262
return res;
208263
}
209264

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package org.apache.sysds.runtime.einsum;
221

322
import org.apache.commons.logging.Log;
@@ -11,7 +30,12 @@ public EOpNodeData(Character c1, Character c2, int matrixIdx){
1130
super(c1,c2);
1231
this.matrixIdx = matrixIdx;
1332
}
14-
33+
@Override
34+
public String[] recursivePrintString() {
35+
String[] res = new String[1];
36+
res[0] = this.getClass().getSimpleName()+" ("+matrixIdx+") "+this.toString();
37+
return res;
38+
}
1539
@Override
1640
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG) {
1741
return inputs.get(matrixIdx);

0 commit comments

Comments
 (0)