|
23 | 23 | import org.apache.commons.logging.Log; |
24 | 24 | import org.apache.commons.logging.LogFactory; |
25 | 25 | import org.apache.sysds.common.Types.DataType; |
| 26 | +import org.apache.sysds.hops.LiteralOp; |
26 | 27 | import org.apache.sysds.hops.OptimizerUtils; |
| 28 | +import org.apache.sysds.hops.codegen.cplan.CNode; |
27 | 29 | import org.apache.sysds.hops.codegen.cplan.CNodeCell; |
| 30 | +import org.apache.sysds.hops.codegen.cplan.CNodeData; |
28 | 31 | import org.apache.sysds.hops.codegen.cplan.CNodeRow; |
29 | 32 | import org.apache.sysds.runtime.codegen.*; |
30 | 33 | import org.apache.sysds.runtime.compress.CompressedMatrixBlock; |
@@ -164,7 +167,9 @@ public void processInstruction(ExecutionContext ec) { |
164 | 167 |
|
165 | 168 | if(inputsChars.length == 2 && inputsChars[0].charAt(0)==inputsChars[1].charAt(0) && !einc.summingChars.contains(parts[1].charAt(0))){// ja,jb->... |
166 | 169 | // outer tmpl |
167 | | - CNodeRow cnode = new CNodeRow(new ArrayList<>(), null); |
| 170 | + ArrayList<CNode> cnodeIn = new ArrayList<>(); |
| 171 | + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); |
| 172 | + CNodeRow cnode = new CNodeRow(cnodeIn, null); |
168 | 173 | // cnode.setConstDim2(einc.outCols); |
169 | 174 | // cnode.setNumVectorIntermediates(1); |
170 | 175 | String src = tmpRow; |
@@ -208,8 +213,9 @@ public void processInstruction(ExecutionContext ec) { |
208 | 213 | else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].charAt(0)){ |
209 | 214 | ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier |
210 | 215 | MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0 ,0, 0); |
211 | | - |
212 | | - CNodeRow cnode = new CNodeRow(new ArrayList<>(), null); |
| 216 | + ArrayList<CNode> cnodeIn = new ArrayList<>(); |
| 217 | + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); |
| 218 | + CNodeRow cnode = new CNodeRow(cnodeIn, null); |
213 | 219 | String src = tmpRow; |
214 | 220 |
|
215 | 221 | if(einc.outCols == 1){ |
@@ -246,7 +252,9 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char |
246 | 252 | } |
247 | 253 | } |
248 | 254 | else{ //fallback to cell |
249 | | - CNodeCell cnode = new CNodeCell(new ArrayList<>(), null); |
| 255 | + ArrayList<CNode> cnodeIn = new ArrayList<>(); |
| 256 | + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); |
| 257 | + CNodeCell cnode = new CNodeCell(cnodeIn, null); |
250 | 258 | // cnode.setCellType(SpoofCellwise.CellType.NO_AGG); |
251 | 259 | StringBuilder sb = new StringBuilder(); |
252 | 260 |
|
|
0 commit comments