Skip to content

Commit f740c1f

Browse files
move eopnode compute impl. to respective classes
1 parent 33fdd1b commit f740c1f

File tree

6 files changed

+319
-299
lines changed

6 files changed

+319
-299
lines changed

src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package org.apache.sysds.runtime.einsum;
22

3+
import org.apache.commons.logging.Log;
4+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
5+
6+
import java.util.ArrayList;
7+
38
public abstract class EOpNode {
49
public Character c1;
510
public Character c2; // nullable
@@ -15,5 +20,7 @@ public String toString() {
1520
if(c2 == null) return c1.toString();
1621
return c1.toString() + c2.toString();
1722
}
23+
24+
public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG);
1825
}
1926

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 161 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,29 @@
11
package org.apache.sysds.runtime.einsum;
22

3+
import org.apache.commons.logging.Log;
4+
import org.apache.sysds.runtime.functionobjects.Multiply;
5+
import org.apache.sysds.runtime.functionobjects.Plus;
6+
import org.apache.sysds.runtime.functionobjects.ReduceAll;
7+
import org.apache.sysds.runtime.functionobjects.ReduceCol;
8+
import org.apache.sysds.runtime.functionobjects.ReduceRow;
9+
import org.apache.sysds.runtime.functionobjects.SwapIndex;
10+
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
11+
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
12+
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
13+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
14+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
15+
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
16+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
17+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
18+
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
19+
20+
import java.util.ArrayList;
21+
22+
import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
23+
import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector;
24+
325
public class EOpNodeBinary extends EOpNode {
26+
427
public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
528
////// summations: //////
629
aB_a,// -> B
@@ -31,13 +54,145 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b
3154
AB_scalar, // m-scalar
3255
scalar_scalar
3356
}
34-
public EOpNode left;
35-
public EOpNode right;
36-
public EBinaryOperand operand;
57+
public EOpNode _left;
58+
public EOpNode _right;
59+
public EBinaryOperand _operand;
3760
public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){
3861
super(c1,c2);
39-
this.left = left;
40-
this.right = right;
41-
this.operand = operand;
62+
this._left = left;
63+
this._right = right;
64+
this._operand = operand;
4265
}
66+
67+
@Override
68+
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) {
69+
EOpNodeBinary bin = this;
70+
MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG);
71+
MatrixBlock right = _right.computeEOpNode(inputs, numThreads, LOG);
72+
73+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
74+
75+
MatrixBlock res;
76+
77+
LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin);
78+
79+
switch (bin._operand){
80+
case AB_AB -> {
81+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
82+
}
83+
case A_A -> {
84+
ensureMatrixBlockColumnVector(left);
85+
ensureMatrixBlockColumnVector(right);
86+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
87+
}
88+
case a_a -> {
89+
ensureMatrixBlockColumnVector(left);
90+
ensureMatrixBlockColumnVector(right);
91+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
92+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
93+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
94+
}
95+
////////////
96+
case Ba_Ba -> {
97+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
98+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
99+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
100+
}
101+
case aB_aB -> {
102+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
103+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
104+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
105+
ensureMatrixBlockColumnVector(res);
106+
}
107+
case ab_ab -> {
108+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
109+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
110+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
111+
}
112+
case ab_ba -> {
113+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
114+
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
115+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
116+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
117+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
118+
}
119+
case Ba_aB -> {
120+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
121+
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
122+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
123+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
124+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
125+
}
126+
case aB_Ba -> {
127+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
128+
left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
129+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
130+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
131+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
132+
}
133+
134+
/////////
135+
case AB_BA -> {
136+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
137+
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
138+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
139+
}
140+
case Ba_aC -> {
141+
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads);
142+
}
143+
case aB_Ca -> {
144+
res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), numThreads);
145+
}
146+
case Ba_Ca -> {
147+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
148+
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
149+
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads);
150+
}
151+
case aB_aC -> {
152+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
153+
left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
154+
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads);
155+
}
156+
case A_scalar, AB_scalar -> {
157+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
158+
}
159+
case BA_A -> {
160+
ensureMatrixBlockRowVector(right);
161+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
162+
}
163+
case Ba_a -> {
164+
ensureMatrixBlockRowVector(right);
165+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
166+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
167+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
168+
}
169+
170+
case AB_A -> {
171+
ensureMatrixBlockColumnVector(right);
172+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
173+
}
174+
case aB_a -> {
175+
ensureMatrixBlockColumnVector(right);
176+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
177+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
178+
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
179+
ensureMatrixBlockColumnVector(res);
180+
}
181+
182+
case A_B -> {
183+
ensureMatrixBlockColumnVector(left);
184+
ensureMatrixBlockRowVector(right);
185+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
186+
}
187+
case scalar_scalar -> {
188+
return new MatrixBlock(left.get(0,0)*right.get(0,0));
189+
}
190+
default -> {
191+
throw new IllegalArgumentException("Unexpected value: " + bin._operand.toString());
192+
}
193+
194+
}
195+
return res;
196+
}
197+
43198
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
package org.apache.sysds.runtime.einsum;
22

3+
import org.apache.commons.logging.Log;
4+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
5+
6+
import java.util.ArrayList;
7+
38
public class EOpNodeData extends EOpNode {
49
public int matrixIdx;
510
public EOpNodeData(Character c1, Character c2, int matrixIdx){
611
super(c1,c2);
712
this.matrixIdx = matrixIdx;
813
}
14+
15+
@Override
16+
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG) {
17+
return inputs.get(matrixIdx);
18+
}
919
}

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java

Lines changed: 133 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,30 @@
11
package org.apache.sysds.runtime.einsum;
22

3+
import org.apache.commons.logging.Log;
4+
import org.apache.sysds.runtime.codegen.SpoofRowwise;
5+
import org.apache.sysds.runtime.functionobjects.Plus;
6+
import org.apache.sysds.runtime.functionobjects.ReduceCol;
7+
import org.apache.sysds.runtime.functionobjects.ReduceRow;
8+
import org.apache.sysds.runtime.functionobjects.SwapIndex;
39
import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
410
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
11+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
12+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
13+
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
14+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
15+
import org.jetbrains.annotations.NotNull;
516

617
import java.util.ArrayList;
718
import java.util.Arrays;
819
import java.util.HashMap;
920
import java.util.HashSet;
1021
import java.util.List;
22+
import java.util.stream.Collectors;
23+
24+
import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
1125

1226
public class EOpNodeEinsumFuse extends EOpNode {
13-
public static final int AB_index=0;
14-
public static final int BA_index=1;
15-
public static final int B_index=2;
16-
public static final int XB_index=3;
17-
public static final int BX_index=4;
18-
public static final int A_index=5;
19-
public static final int XA_index=6;
20-
public static final int AX_index=7;
21-
public static final int AZ_index=8;
27+
2228
public enum EinsumRewriteType{
2329
// B -> row*row, A -> row*scalar
2430
AB_BA_B_A__AB,
@@ -299,5 +305,123 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) {
299305
ret.add(e);
300306
return e;
301307
}
308+
309+
@Override
310+
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) {
311+
List<List<MatrixBlock>> mbs = operands.stream().map(l -> l.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).collect(Collectors.toList())).toList();
312+
var eOpNodeEinsumFuse = this;
313+
314+
if( LOG.isTraceEnabled()) {
315+
String x = eOpNodeEinsumFuse.operands.stream()
316+
.flatMap(List::stream)
317+
.map(o -> o.c1.toString() + (o.c2 == null ? "" : o.c2))
318+
.collect(Collectors.joining(","));
319+
String res = (eOpNodeEinsumFuse.c1 == null ? "" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString());
320+
LOG.trace("ComputeEOpNodeFuse " + eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res);
321+
}
322+
boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB;
323+
boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A;
324+
boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B;
325+
boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__;
326+
boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z;
327+
boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ;
328+
boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB;
329+
// boolean isResultBC = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AC__BC;
330+
// boolean isResultCB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB;
331+
List<MatrixBlock> ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7);
332+
List<MatrixBlock> AZs = mbs.get(8);
333+
List<MatrixBlock> Zs = mbs.get(9);
334+
// List<MatrixBlock> ACs = isResultBC || isResultCB ? mbs.get(10) : null;
335+
int bSize = ABs.get(0).getNumColumns();
336+
int aSize = ABs.get(0).getNumRows();
337+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
338+
for(MatrixBlock mb: BAs){//BA->AB
339+
ABs.add(mb.reorgOperations(transpose, null,0,0,0));
340+
}
341+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
342+
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
343+
for(MatrixBlock mb: XBs){//XB->B
344+
MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false);
345+
Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null));
346+
}
347+
for(MatrixBlock mb: XAs){//XA->A
348+
MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false);
349+
As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null));
350+
}
351+
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
352+
for(MatrixBlock mb: BXs){//BX->B
353+
MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false);
354+
As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null));
355+
}
356+
for(MatrixBlock mb: AXs){//AX->B // todo remove all X
357+
MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false);
358+
As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null));
359+
}
360+
if(As.size()>1){
361+
As = multiplyVectorsIntoOne(As, aSize);
362+
}
363+
if(Bs.size() > 1){
364+
Bs = multiplyVectorsIntoOne(Bs, bSize);
365+
}
366+
if(Zs != null && Zs.size() > 1){
367+
Zs = multiplyVectorsIntoOne(Zs, AZs.get(0).getNumColumns());
368+
}
369+
int constDim2 = -1;
370+
int zSize = 0;
371+
int azCount = 0;
372+
int zCount = 0;
373+
switch(eOpNodeEinsumFuse.einsumRewriteType){
374+
case AB_BA_B_A_AZ__Z -> {
375+
constDim2 = AZs.get(0).getNumColumns();
376+
zSize = AZs.get(0).getNumColumns();
377+
azCount = AZs.size();
378+
if (Zs != null) zCount = Zs.size();
379+
}
380+
case AB_BA_B_A_AZ__BZ, AB_BA_B_A_AZ__ZB -> {
381+
constDim2 = AZs.get(0).getNumColumns();
382+
zSize = AZs.get(0).getNumColumns();
383+
azCount = AZs.size();
384+
}
385+
}
386+
387+
SpoofRowwise.RowType rowType = switch(eOpNodeEinsumFuse.einsumRewriteType){
388+
case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG;
389+
case AB_BA_B_A__B -> SpoofRowwise.RowType.COL_AGG_T;
390+
case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG;
391+
case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG;
392+
case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST;
393+
case AB_BA_B_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T;
394+
case AB_BA_B_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1;
395+
};
396+
EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 0, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize);
397+
398+
399+
ArrayList<MatrixBlock> fuseInputs = new ArrayList<>();
400+
// inputs.add(resBlock);
401+
402+
fuseInputs.addAll(ABs);
403+
fuseInputs.addAll(Bs);
404+
fuseInputs.addAll(As);
405+
if (isResultZ || isResultBZ || isResultZB)
406+
fuseInputs.addAll(AZs);
407+
MatrixBlock out = r.execute(fuseInputs, new ArrayList<>(), new MatrixBlock(), numThreads);
408+
if( isResultA || isResultB || isResultZ)
409+
ensureMatrixBlockColumnVector(out);
410+
return out;
411+
412+
}
413+
414+
private static @NotNull List<MatrixBlock> multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
415+
MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false);
416+
mb.allocateDenseBlock();
417+
for(int i = 1; i< mbs.size(); i++) { // multiply Bs
418+
if(i==1){
419+
LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size);
420+
}else{
421+
LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0, size);
422+
}
423+
}
424+
return List.of(mb);
425+
}
302426
}
303427

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)