Skip to content

Commit 7eb950c

Browse files
create einsum test files from configuration list, fixes in code
1 parent b854305 commit 7eb950c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+264
-1367
lines changed

src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public void processInstruction(ExecutionContext ec) {
8282
EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs);
8383

8484
this.einc = einc;
85-
String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : null;
85+
String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : "";
8686

8787
if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols);
8888

@@ -184,7 +184,9 @@ public void processInstruction(ExecutionContext ec) {
184184
throw new RuntimeException("Einsum runtime error, reductions and multiplications finished but the did not produce one result"); // should not happen
185185
}
186186
}
187-
ec.setMatrixOutput(output.getName(), res);
187+
if (einc.outRows == 1 && einc.outCols == 1)
188+
ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0)));
189+
else ec.setMatrixOutput(output.getName(), res);
188190
}
189191

190192
else { // if (needToDoCellTemplate)
@@ -294,6 +296,7 @@ private static void EnsureMatrixBlockRowVector(MatrixBlock mb){
294296
/* handle situation: ji,ji or ij,ji, j,j */
295297
private boolean multiplyTerms(HashMap<Character, ArrayList<Integer>> partsCharactersToIndices, ArrayList<MatrixBlock> inputs, ArrayList<String> inputsChars, Character outChar1, Character outChar2 ) {
296298
HashMap<String, LinkedList<Integer>> matrixStringToIndex = new HashMap<>();
299+
HashSet<String> matrixStringToIndexSkip = new HashSet<>();
297300
HashMap<String, LinkedList<Integer>> vectorStringToIndex = new HashMap<>();
298301

299302
for(int i = 0; i < inputsChars.size(); i++){
@@ -352,7 +355,7 @@ private boolean multiplyTerms(HashMap<Character, ArrayList<Integer>> partsCharac
352355
}
353356

354357
for(var s : matrixStringToIndex.keySet()){
355-
if(!matrixStringToIndex.containsKey(s)) continue; // entries can be removed
358+
if(matrixStringToIndexSkip.contains(s)) continue;
356359

357360
String sT = s.length() == 2 ? String.valueOf(s.charAt(1)) + s.charAt(0) : null;
358361
LinkedList<Integer> idxs = matrixStringToIndex.get(s);
@@ -459,7 +462,7 @@ private boolean multiplyTerms(HashMap<Character, ArrayList<Integer>> partsCharac
459462
if (partsCharactersToIndices.containsKey(c)) partsCharactersToIndices.get(c).add(inputs.size() - 1);
460463
}
461464

462-
if(idxsT != null) matrixStringToIndex.remove(sT);
465+
if(idxsT != null) matrixStringToIndexSkip.add(sT);
463466
}
464467

465468

@@ -542,9 +545,11 @@ private Pair<MatrixBlock, String> computeRowSummation(List<Integer> toSum, Array
542545
String resS;
543546
SumOperation sumOp;
544547

545-
if(s1.length()==1 && s2.length() == 1){ //remove never happening here
548+
if(s1.length()==1 && s2.length() == 1){
546549
sumOp = SumOperation.a_a;
547550
resS = "";
551+
first = inputs.get(toSum.get(0));
552+
second = inputs.get(toSum.get(1));
548553
}
549554
else if(s2.length() == 1 || s1.length() == 1){
550555
if(s1.length() == 1){

0 commit comments

Comments
 (0)