Skip to content

Commit 5dbeed7

Browse files
remove comments
1 parent e063fc8 commit 5dbeed7

File tree

9 files changed

+298
-274
lines changed

9 files changed

+298
-274
lines changed

src/main/java/org/apache/sysds/hops/NaryOp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ else if ( areDimsBelowThreshold() )
165165
setRequiresRecompileIfNecessary();
166166

167167
//ensure cp exec type for single-node operations
168-
if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST
168+
if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST || _op == OpOpN.EINSUM
169169
//TODO: cbind/rbind of lists only support in CP right now
170170
|| (_op == OpOpN.CBIND && getInput().get(0).getDataType().isList())
171171
|| (_op == OpOpN.RBIND && getInput().get(0).getDataType().isList())

src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import scala.Int;
2525

2626
import java.util.ArrayList;
27+
import java.util.Arrays;
28+
import java.util.List;
2729

2830
public abstract class EOpNode {
2931
public Character c1;
@@ -37,14 +39,28 @@ public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
3739
this.dim2 = dim2;
3840
}
3941

40-
@Override
41-
public String toString() {
42+
public String getOutputString() {
4243
if(c1 == null) return "''";
4344
if(c2 == null) return c1.toString();
4445
return c1.toString() + c2.toString();
4546
}
47+
public abstract List<EOpNode> getChildren();
4648

47-
public abstract String[] recursivePrintString();
49+
public String[] recursivePrintString(){
50+
ArrayList<String[]> inpStrings = new ArrayList<>();
51+
for (EOpNode node : getChildren()) {
52+
inpStrings.add(node.recursivePrintString());
53+
}
54+
String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new);
55+
String[] res = new String[1 + inpRes.length];
56+
57+
res[0] = this.toString();
58+
59+
for (int i=0; i<inpRes.length; i++) {
60+
res[i+1] = (i==0 ? "┌ " : (i==inpRes.length-1 ? "└ " : "| "))+inpRes[i];
61+
}
62+
return res;
63+
};
4864

4965
public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG);
5066

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@
4646

4747
public class EOpNodeBinary extends EOpNode {
4848

49-
public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
49+
public enum EBinaryOperand { // upper case: char remains, lower case: summed (reduced) dimension
5050
////// mm: //////
5151
Ba_aC, // -> BC
5252
aB_Ca, // -> CB
5353
Ba_Ca, // -> BC
5454
aB_aC, // -> BC
5555

56-
////// elementwisemult and sums //////
56+
////// element-wise multiplications and sums //////
5757
aB_aB,// elemwise and colsum -> B
5858
Ab_Ab, // elemwise and rowsum ->A
5959
Ab_bA, // elemwise, either colsum or rowsum -> A
@@ -169,18 +169,12 @@ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
169169
}
170170

171171
@Override
172-
public String[] recursivePrintString() {
173-
String[] left = this.left.recursivePrintString();
174-
String[] right = this.right.recursivePrintString();
175-
String[] res = new String[left.length + right.length+1];
176-
res[0] = this.getClass().getSimpleName()+" ("+ operand.toString()+") "+this.toString();
177-
for (int i=0; i<left.length; i++) {
178-
res[i+1] = (i==0 ? "┌─ " : " ") +left[i];
179-
}
180-
for (int i=0; i<right.length; i++) {
181-
res[left.length+i+1] = (i==0 ? "└─ " : " ") +right[i];
182-
}
183-
return res;
172+
public List<EOpNode> getChildren() {
173+
return List.of(this.left, this.right);
174+
}
175+
@Override
176+
public String toString() {
177+
return this.getClass().getSimpleName()+" ("+ operand.toString()+") "+getOutputString();
184178
}
185179

186180
@Override
@@ -193,8 +187,6 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
193187

194188
MatrixBlock res;
195189

196-
if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin.left +","+bin.right +"->"+bin);
197-
198190
switch (bin.operand){
199191
case AB_AB -> {
200192
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
@@ -212,22 +204,28 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
212204
res.getDenseBlockValues()[0] = LibMatrixMult.dotProduct(left.getDenseBlockValues(), right.getDenseBlockValues(), 0,0 , left.getNumRows());
213205
}
214206
case Ab_Ab -> {
215-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads);
207+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
208+
null, numThreads);
216209
}
217210
case aB_aB -> {
218-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads);
211+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
212+
null, numThreads);
219213
}
220214
case ab_ab -> {
221-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads);
215+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
216+
null, numThreads);
222217
}
223218
case ab_ba -> {
224-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads);
219+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
220+
null, numThreads);
225221
}
226222
case Ab_bA -> {
227-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads);
223+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
224+
null, numThreads);
228225
}
229226
case aB_Ba -> {
230-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads);
227+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
228+
null, numThreads);
231229
}
232230
case AB_BA -> {
233231
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
@@ -271,14 +269,16 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
271269
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
272270
}
273271
case Ab_b -> {
274-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(),null,numThreads);
272+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(),
273+
null, numThreads);
275274
}
276275
case AB_A -> {
277276
ensureMatrixBlockColumnVector(right);
278277
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
279278
}
280279
case aB_a -> {
281-
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),null,numThreads);
280+
res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),
281+
null, numThreads);
282282
}
283283
case A_B -> {
284284
ensureMatrixBlockColumnVector(left);
@@ -427,10 +427,6 @@ else if (n1.c2 == n2.c1) {
427427
return null; // AB_B
428428
}else{
429429
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2));
430-
// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
431-
// return null; // AB_B
432-
// }
433-
// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
434430
}
435431
}
436432
if(n1.c1 == n2.c2) {

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,22 @@
2323
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2424

2525
import java.util.ArrayList;
26+
import java.util.List;
2627

2728
public class EOpNodeData extends EOpNode {
2829
public int matrixIdx;
2930
public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, int matrixIdx){
3031
super(c1,c2,dim1,dim2);
3132
this.matrixIdx = matrixIdx;
3233
}
34+
35+
@Override
36+
public List<EOpNode> getChildren() {
37+
return List.of();
38+
}
3339
@Override
34-
public String[] recursivePrintString() {
35-
String[] res = new String[1];
36-
res[0] = this.getClass().getSimpleName()+" ("+matrixIdx+") "+this.toString();
37-
return res;
40+
public String toString() {
41+
return this.getClass().getSimpleName()+" ("+matrixIdx+") "+getOutputString();
3842
}
3943
@Override
4044
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG) {

0 commit comments

Comments
 (0)