Skip to content

Commit c9c9947

Browse files
quick fix
1 parent 8db99c9 commit c9c9947

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
import org.apache.commons.logging.Log;
2424
import org.apache.commons.logging.LogFactory;
2525
import org.apache.sysds.common.Types.DataType;
26+
import org.apache.sysds.hops.LiteralOp;
2627
import org.apache.sysds.hops.OptimizerUtils;
28+
import org.apache.sysds.hops.codegen.cplan.CNode;
2729
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
30+
import org.apache.sysds.hops.codegen.cplan.CNodeData;
2831
import org.apache.sysds.hops.codegen.cplan.CNodeRow;
2932
import org.apache.sysds.runtime.codegen.*;
3033
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
@@ -164,7 +167,9 @@ public void processInstruction(ExecutionContext ec) {
164167

165168
if(inputsChars.length == 2 && inputsChars[0].charAt(0)==inputsChars[1].charAt(0) && !einc.summingChars.contains(parts[1].charAt(0))){// ja,jb->...
166169
// outer tmpl
167-
CNodeRow cnode = new CNodeRow(new ArrayList<>(), null);
170+
ArrayList<CNode> cnodeIn = new ArrayList<>();
171+
cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR));
172+
CNodeRow cnode = new CNodeRow(cnodeIn, null);
168173
// cnode.setConstDim2(einc.outCols);
169174
// cnode.setNumVectorIntermediates(1);
170175
String src = tmpRow;
@@ -208,8 +213,9 @@ public void processInstruction(ExecutionContext ec) {
208213
else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].charAt(0)){
209214
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier
210215
MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0 ,0, 0);
211-
212-
CNodeRow cnode = new CNodeRow(new ArrayList<>(), null);
216+
ArrayList<CNode> cnodeIn = new ArrayList<>();
217+
cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR));
218+
CNodeRow cnode = new CNodeRow(cnodeIn, null);
213219
String src = tmpRow;
214220

215221
if(einc.outCols == 1){
@@ -246,7 +252,9 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char
246252
}
247253
}
248254
else{ //fallback to cell
249-
CNodeCell cnode = new CNodeCell(new ArrayList<>(), null);
255+
ArrayList<CNode> cnodeIn = new ArrayList<>();
256+
cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR));
257+
CNodeCell cnode = new CNodeCell(cnodeIn, null);
250258
// cnode.setCellType(SpoofCellwise.CellType.NO_AGG);
251259
StringBuilder sb = new StringBuilder();
252260

0 commit comments

Comments
 (0)