Skip to content

Commit 9e1ae4f

Browse files
committed
Merge branch 'feat/java-matrix-sqrt-implementation' into feat/sqrt-matrix-combined
2 parents ad5be66 + 1c1c71d commit 9e1ae4f

File tree

7 files changed

+34
-4
lines changed

7 files changed

+34
-4
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ public enum Builtins {
326326
STFT("stft", false, ReturnType.MULTI_RETURN),
327327
SQRT("sqrt", false),
328328
SQRT_MATRIX("matrixSqrt", true),
329+
SQRT_MATRIX_JAVA("sqrt_matrix_java", false, ReturnType.SINGLE_RETURN),
329330
SUM("sum", false),
330331
SVD("svd", false, ReturnType.MULTI_RETURN),
331332
TABLE("table", "ctable", false),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ public enum OpOp1 {
542542
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
543543
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
544544
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
545-
SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
545+
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
546546
//fused ML-specific operators for performance
547547
SPROP, //sample proportion: P * (1 - P)
548548
SIGMOID, //sigmoid function: 1 / (1 + exp(-X))

src/main/java/org/apache/sysds/hops/UnaryOp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent
512512

513513
//ensure cp exec type for single-node operations
514514
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.TYPEOF
515-
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD
515+
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
516516
|| getInput().get(0).getDataType() == DataType.LIST || isMetadataOperation() )
517517
{
518518
_etype = ExecType.CP;

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,19 @@ && isConstant(in[2])
17591759
output.setDimensions(in.getDim1(), in.getDim2());
17601760
output.setBlocksize(in.getBlocksize());
17611761
break;
1762+
1763+
case SQRT_MATRIX_JAVA:
1764+
1765+
checkNumParameters(1);
1766+
checkMatrixParam(getFirstExpr());
1767+
output.setDataType(DataType.MATRIX);
1768+
output.setValueType(ValueType.FP64);
1769+
Identifier sqrt = getFirstExpr().getOutput();
1770+
if(sqrt.dimsKnown() && sqrt.getDim1() != sqrt.getDim2())
1771+
raiseValidateError("Input to sqrtMatrix() must be square matrix -- given: a " + sqrt.getDim1() + "x" + sqrt.getDim2() + " matrix.", conditional);
1772+
output.setDimensions( sqrt.getDim1(), sqrt.getDim2());
1773+
output.setBlocksize( sqrt.getBlocksize());
1774+
break;
17621775

17631776
case CHOLESKY:
17641777
{

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2749,6 +2749,7 @@ else if ( in.length == 2 )
27492749
break;
27502750

27512751
case INVERSE:
2752+
case SQRT_MATRIX_JAVA:
27522753
case CHOLESKY:
27532754
case TYPEOF:
27542755
case DETECTSCHEMA:

src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ public class CPInstructionParser extends InstructionParser {
208208
String2CPInstructionType.put( "ucummax", CPType.Unary);
209209
String2CPInstructionType.put( "stop" , CPType.Unary);
210210
String2CPInstructionType.put( "inverse", CPType.Unary);
211+
String2CPInstructionType.put( "sqrt_matrix_java", CPType.Unary);
211212
String2CPInstructionType.put( "cholesky",CPType.Unary);
212213
String2CPInstructionType.put( "sprop", CPType.Unary);
213214
String2CPInstructionType.put( "sigmoid", CPType.Unary);

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ private LibCommonsMath() {
8080
}
8181

8282
public static boolean isSupportedUnaryOperation( String opcode ) {
83-
return ( opcode.equals("inverse") || opcode.equals("cholesky") );
83+
return ( opcode.equals("inverse") || opcode.equals("cholesky") || opcode.equals("sqrt_matrix_java") );
8484
}
8585

8686
public static boolean isSupportedMultiReturnOperation( String opcode ) {
@@ -111,6 +111,8 @@ public static MatrixBlock unaryOperations(MatrixBlock inj, String opcode) {
111111
return computeMatrixInverse(matrixInput);
112112
else if (opcode.equals("cholesky"))
113113
return computeCholesky(matrixInput);
114+
else if (opcode.equals("sqrt_matrix_java"))
115+
return computeSqrt(inj);
114116
return null;
115117
}
116118

@@ -512,7 +514,19 @@ private static MatrixBlock[] computeSvd(MatrixBlock in) {
512514

513515
return new MatrixBlock[] { U, Sigma, V };
514516
}
515-
517+
518+
/**
519+
* Computes the square root of a matrix Calls Apache Commons Math EigenDecomposition.
520+
*
521+
* @param in Input matrix
522+
* @return matrix block
523+
*/
524+
private static MatrixBlock computeSqrt(MatrixBlock in) {
525+
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
526+
EigenDecomposition ed = new EigenDecomposition(matrixInput);
527+
return DataConverter.convertToMatrixBlock(ed.getSquareRoot());
528+
}
529+
516530
/**
517531
* Function to compute matrix inverse via matrix decomposition.
518532
*

0 commit comments

Comments
 (0)