Skip to content

Commit c66b39f

Browse files
committed
[SYSTEMDS-2864] Fix opcode merge conflicts and autodiff bug
* revert the bad merge of the previous sqrt_matrix modification * fix the handling of opcodes in the autodiff program reconstruction
1 parent c416664 commit c66b39f

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ public enum Opcodes {
150150
ATAN("atan", CPType.Unary),
151151
SIGN("sign", CPType.Unary),
152152
SQRT("sqrt", CPType.Unary),
153+
SQRT_MATRIX_JAVA("sqrt_matrix_java", CPType.Unary),
153154
PLOGP("plogp", CPType.Unary),
154155
PRINT("print", CPType.Unary),
155156
ASSERT("assert", CPType.Unary),

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ private LibCommonsMath() {
8181
}
8282

8383
public static boolean isSupportedUnaryOperation( String opcode ) {
84-
return ( opcode.equals(Opcodes.INVERSE.toString()) || opcode.equals(Opcodes.CHOLESKY.toString()) );
84+
return opcode.equals(Opcodes.INVERSE.toString())
85+
|| opcode.equals(Opcodes.CHOLESKY.toString())
86+
|| opcode.equals(Opcodes.SQRT_MATRIX_JAVA.toString());
8587
}
8688

8789
public static boolean isSupportedMultiReturnOperation( String opcode ) {
@@ -107,11 +109,15 @@ public static boolean isSupportedMatrixMatrixOperation( String opcode ) {
107109
}
108110

109111
public static MatrixBlock unaryOperations(MatrixBlock inj, String opcode) {
110-
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(inj);
111-
if(opcode.equals(Opcodes.INVERSE.toString()))
112-
return computeMatrixInverse(matrixInput);
113-
else if (opcode.equals(Opcodes.CHOLESKY.toString()))
114-
return computeCholesky(matrixInput);
112+
if (opcode.equals(Opcodes.SQRT_MATRIX_JAVA.toString()))
113+
return computeSqrt(inj);
114+
else {
115+
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(inj);
116+
if(opcode.equals(Opcodes.INVERSE.toString()))
117+
return computeMatrixInverse(matrixInput);
118+
else if (opcode.equals(Opcodes.CHOLESKY.toString()))
119+
return computeCholesky(matrixInput);
120+
}
115121
return null;
116122
}
117123

@@ -514,6 +520,18 @@ private static MatrixBlock[] computeSvd(MatrixBlock in) {
514520
return new MatrixBlock[] { U, Sigma, V };
515521
}
516522

523+
/**
524+
* Computes the square root of a matrix Calls Apache Commons Math EigenDecomposition.
525+
*
526+
* @param in Input matrix
527+
* @return matrix block
528+
*/
529+
private static MatrixBlock computeSqrt(MatrixBlock in) {
530+
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
531+
EigenDecomposition ed = new EigenDecomposition(matrixInput);
532+
return DataConverter.convertToMatrixBlock(ed.getSquareRoot());
533+
}
534+
517535
/**
518536
* Function to compute matrix inverse via matrix decomposition.
519537
*

src/main/java/org/apache/sysds/runtime/util/AutoDiff.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.Stack;
2727

2828
import org.apache.commons.lang3.mutable.MutableInt;
29+
import org.apache.sysds.common.Opcodes;
2930
import org.apache.sysds.common.Types;
3031
import org.apache.sysds.hops.DataGenOp;
3132
import org.apache.sysds.hops.DataOp;
@@ -213,7 +214,7 @@ else if(inst instanceof RandSPInstruction) {
213214
break;
214215
}
215216
case Instruction: {
216-
CPInstruction.CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
217+
CPInstruction.CPType ctype = Opcodes.getCPTypeByOpcode(item.getOpcode());
217218

218219
if(ctype != null) {
219220
switch(ctype) {

0 commit comments

Comments
 (0)