@@ -82,7 +82,7 @@ public void processInstruction(ExecutionContext ec) {
8282 EinsumContext einc = EinsumContext .getEinsumContext (eqStr , inputs );
8383
8484 this .einc = einc ;
85- String resultString = einc .outChar2 != null ? String .valueOf (einc .outChar1 ) + einc .outChar2 : einc .outChar1 != null ? String .valueOf (einc .outChar1 ) : null ;
85+ String resultString = einc .outChar2 != null ? String .valueOf (einc .outChar1 ) + einc .outChar2 : einc .outChar1 != null ? String .valueOf (einc .outChar1 ) : "" ;
8686
8787 if ( LOG .isDebugEnabled () ) LOG .trace ("outrows:" +einc .outRows +", outcols:" +einc .outCols );
8888
@@ -184,7 +184,9 @@ public void processInstruction(ExecutionContext ec) {
184184 throw new RuntimeException ("Einsum runtime error, reductions and multiplications finished but the did not produce one result" ); // should not happen
185185 }
186186 }
187- ec .setMatrixOutput (output .getName (), res );
187+ if (einc .outRows == 1 && einc .outCols == 1 )
188+ ec .setScalarOutput (output .getName (), new DoubleObject (res .get (0 , 0 )));
189+ else ec .setMatrixOutput (output .getName (), res );
188190 }
189191
190192 else { // if (needToDoCellTemplate)
@@ -294,6 +296,7 @@ private static void EnsureMatrixBlockRowVector(MatrixBlock mb){
294296 /* handle situation: ji,ji or ij,ji, j,j */
295297 private boolean multiplyTerms (HashMap <Character , ArrayList <Integer >> partsCharactersToIndices , ArrayList <MatrixBlock > inputs , ArrayList <String > inputsChars , Character outChar1 , Character outChar2 ) {
296298 HashMap <String , LinkedList <Integer >> matrixStringToIndex = new HashMap <>();
299+ HashSet <String > matrixStringToIndexSkip = new HashSet <>();
297300 HashMap <String , LinkedList <Integer >> vectorStringToIndex = new HashMap <>();
298301
299302 for (int i = 0 ; i < inputsChars .size (); i ++){
@@ -352,7 +355,7 @@ private boolean multiplyTerms(HashMap<Character, ArrayList<Integer>> partsCharac
352355 }
353356
354357 for (var s : matrixStringToIndex .keySet ()){
355- if (! matrixStringToIndex . containsKey (s )) continue ; // entries can be removed
358+ if (matrixStringToIndexSkip . contains (s )) continue ;
356359
357360 String sT = s .length () == 2 ? String .valueOf (s .charAt (1 )) + s .charAt (0 ) : null ;
358361 LinkedList <Integer > idxs = matrixStringToIndex .get (s );
@@ -459,7 +462,7 @@ private boolean multiplyTerms(HashMap<Character, ArrayList<Integer>> partsCharac
459462 if (partsCharactersToIndices .containsKey (c )) partsCharactersToIndices .get (c ).add (inputs .size () - 1 );
460463 }
461464
462- if (idxsT != null ) matrixStringToIndex . remove (sT );
465+ if (idxsT != null ) matrixStringToIndexSkip . add (sT );
463466 }
464467
465468
@@ -542,9 +545,11 @@ private Pair<MatrixBlock, String> computeRowSummation(List<Integer> toSum, Array
542545 String resS ;
543546 SumOperation sumOp ;
544547
545- if (s1 .length ()==1 && s2 .length () == 1 ){ //remove never happening here
548+ if (s1 .length ()==1 && s2 .length () == 1 ){
546549 sumOp = SumOperation .a_a ;
547550 resS = "" ;
551+ first = inputs .get (toSum .get (0 ));
552+ second = inputs .get (toSum .get (1 ));
548553 }
549554 else if (s2 .length () == 1 || s1 .length () == 1 ){
550555 if (s1 .length () == 1 ){
0 commit comments