Skip to content

Commit 8f7451f

Browse files
authored
Merge branch 'apache:main' into future
2 parents cd7e229 + 9484f11 commit 8f7451f

File tree

161 files changed

+5227
-1679
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

161 files changed

+5227
-1679
lines changed

.github/workflows/javaTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ jobs:
154154
run: mvn jacoco:report
155155

156156
- name: Upload coverage to Codecov
157-
uses: codecov/codecov-action@v5.0.2
157+
uses: codecov/codecov-action@v5.1.2
158158
if: github.repository_owner == 'apache'
159159
with:
160160
fail_ci_if_error: false

scripts/builtin/sqrtMatrix.dml

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Computes the matrix square root B of a matrix A, such that
23+
# A = B %*% B.
24+
#
25+
# INPUT:
26+
# ------------------------------------------------------------------------------
27+
# A Input Matrix A
28+
# S Strategy (COMMON .. java-based commons-math, DML)
29+
# ------------------------------------------------------------------------------
30+
#
31+
# OUTPUT:
32+
# ------------------------------------------------------------------------------
33+
# B Output Matrix B
34+
# ------------------------------------------------------------------------------
35+
36+
37+
m_sqrtMatrix = function(Matrix[Double] A, String S)
38+
return(Matrix[Double] B)
39+
{
40+
if (S == "COMMON") {
41+
B = sqrtMatrixJava(A)
42+
} else if (S == "DML") {
43+
N = nrow(A);
44+
D = ncol(A);
45+
46+
#check that matrix is square
47+
if (D != N){
48+
stop("matrixSqrt Input Error: matrix not square!")
49+
}
50+
51+
# Any non singualar square matrix has a square root
52+
isDiag = isDiagonal(A)
53+
if(isDiag) {
54+
B = sqrtDiagMatrix(A);
55+
} else {
56+
[eValues, eVectors] = eigen(A);
57+
58+
hasNonNegativeEigenValues = (sum(eValues >= 0) == length(eValues));
59+
60+
if(!hasNonNegativeEigenValues) {
61+
stop("matrixSqrt exec Error: matrix has imaginary square root");
62+
}
63+
64+
isSymmetric = sum(A == t(A)) == length(A);
65+
allEigenValuesUnique = length(eValues) == length(unique(eValues));
66+
67+
if(allEigenValuesUnique | isSymmetric) {
68+
# calculate X = VDV^(-1) -> S = sqrt(D) -> sqrt_x = VSV^(-1)
69+
sqrtD = sqrtDiagMatrix(diag(eValues));
70+
V_Inv = inv(eVectors);
71+
B = eVectors %*% sqrtD %*% V_Inv;
72+
} else {
73+
#formular: (Denman–Beavers iteration)
74+
Y = A
75+
#identity matrix
76+
Z = diag(matrix(1.0, rows=N, cols=1))
77+
78+
for (x in 1:100) {
79+
Y_new = (1 / 2) * (Y + inv(Z))
80+
Z_new = (1 / 2) * (Z + inv(Y))
81+
Y = Y_new
82+
Z = Z_new
83+
}
84+
B = Y
85+
}
86+
}
87+
} else {
88+
stop("Error: Unknown strategy for matrix square root.")
89+
}
90+
}
91+
92+
# assumes square and diagonal matrix
93+
sqrtDiagMatrix = function(Matrix[Double] X)
94+
return(Matrix[Double] sqrt_x)
95+
{
96+
N = nrow(X);
97+
98+
#check if identity matrix
99+
is_identity = sum(diag(diag(X)) == X)==length(X)
100+
& sum(diag(X) == matrix(1,nrow(X),1))==nrow(X);
101+
102+
if(is_identity)
103+
sqrt_x = X;
104+
else
105+
sqrt_x = diag(sqrt(diag(X)));
106+
}
107+
108+
isDiagonal = function (Matrix[Double] X)
109+
return(boolean diagonal)
110+
{
111+
#all cells should be the same to be diagonal
112+
diagonal = sum(diag(diag(X)) == X) == length(X);
113+
}
114+

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ public enum Builtins {
325325
STEPLM("steplm",true, ReturnType.MULTI_RETURN),
326326
STFT("stft", false, ReturnType.MULTI_RETURN),
327327
SQRT("sqrt", false),
328+
SQRT_MATRIX("sqrtMatrix", true),
329+
SQRT_MATRIX_JAVA("sqrtMatrixJava", false, ReturnType.SINGLE_RETURN),
328330
SUM("sum", false),
329331
SVD("svd", false, ReturnType.MULTI_RETURN),
330332
TABLE("table", "ctable", false),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ public enum OpOp1 {
542542
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
543543
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
544544
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
545-
SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
545+
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
546546
//fused ML-specific operators for performance
547547
SPROP, //sample proportion: P * (1 - P)
548548
SIGMOID, //sigmoid function: 1 / (1 + exp(-X))

src/main/java/org/apache/sysds/conf/DMLConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ public class DMLConfig
201201
_defaultVals.put(FLOATING_POINT_PRECISION, "double" );
202202
_defaultVals.put(USE_SSL_FEDERATED_COMMUNICATION, "false");
203203
_defaultVals.put(DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, "10");
204-
_defaultVals.put(FEDERATED_TIMEOUT, "-1");
204+
_defaultVals.put(FEDERATED_TIMEOUT, "86400"); // default 1 day compute timeout.
205205
_defaultVals.put(FEDERATED_PLANNER, FederatedPlanner.RUNTIME.name());
206206
_defaultVals.put(FEDERATED_PAR_CONN, "-1"); // vcores
207207
_defaultVals.put(FEDERATED_PAR_INST, "-1"); // vcores

src/main/java/org/apache/sysds/hops/UnaryOp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent
512512

513513
//ensure cp exec type for single-node operations
514514
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.TYPEOF
515-
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD
515+
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
516516
|| getInput().get(0).getDataType() == DataType.LIST || isMetadataOperation() )
517517
{
518518
_etype = ExecType.CP;

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
115115
if( LineageCacheConfig.getCompAssRW() )
116116
_sbRuleSet.add( new MarkForLineageReuse() );
117117
_sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() );
118-
}
118+
_dagRuleSet.add( new RewriteNonScalarPrint() );
119+
}
119120

120121
// DYNAMIC REWRITES (which do require size information)
121122
if( dynamicRewrites )

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 69 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -381,30 +381,28 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)
381381

382382
return hi;
383383
}
384-
385-
private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
386-
{
387-
if( hi instanceof ReorgOp )
388-
{
384+
385+
private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) {
386+
if( hi instanceof ReorgOp ) {
389387
ReorgOp rop = (ReorgOp) hi;
390-
Hop input = hi.getInput(0);
388+
Hop input = hi.getInput(0);
391389
boolean apply = false;
392-
393-
//equal dims of reshape input and output -> no need for reshape because
390+
391+
//equal dims of reshape input and output -> no need for reshape because
394392
//byrow always refers to both input/output and hence gives the same result
395393
apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input));
396-
397-
//1x1 dimensions of transpose/reshape -> no need for reorg
398-
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE)
399-
&& rop.getDim1()==1 && rop.getDim2()==1);
400-
394+
395+
//1x1 dimensions of transpose/reshape/roll -> no need for reorg
396+
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE
397+
|| rop.getOp()==ReOrgOp.ROLL) && rop.getDim1()==1 && rop.getDim2()==1);
398+
401399
if( apply ) {
402400
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
403401
hi = input;
404402
LOG.debug("Applied removeUnnecessaryReorg.");
405403
}
406404
}
407-
405+
408406
return hi;
409407
}
410408

@@ -1356,44 +1354,78 @@ else if ( applyRight ) {
13561354
* @param pos position
13571355
* @return high-level operator
13581356
*/
1359-
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
1357+
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
13601358
{
13611359
//all patterns headed by full sum over binary operation
13621360
if( hi instanceof AggUnaryOp //full sum root over binaryop
1363-
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
1364-
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
1365-
&& hi.getInput(0) instanceof BinaryOp
1366-
&& hi.getInput(0).getParent().size()==1 ) //single parent
1361+
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
1362+
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
1363+
&& hi.getInput(0) instanceof BinaryOp
1364+
&& hi.getInput(0).getParent().size()==1 ) //single parent
13671365
{
13681366
BinaryOp bop = (BinaryOp) hi.getInput(0);
13691367
Hop left = bop.getInput(0);
13701368
Hop right = bop.getInput(1);
1371-
1372-
if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B)
1373-
&& left.getDataType() == DataType.MATRIX
1374-
&& right.getDataType() == DataType.MATRIX )
1369+
1370+
if( left.getDataType() == DataType.MATRIX
1371+
&& right.getDataType() == DataType.MATRIX )
13751372
{
13761373
OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
13771374
|| bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
13781375
? bop.getOp() : null;
1379-
1376+
13801377
if( applyOp != null ) {
1381-
//create new subdag sum(A) bop sum(B)
1382-
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1383-
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1384-
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
1385-
1386-
//rewire new subdag
1387-
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1388-
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1389-
1390-
hi = newBin;
1391-
1392-
LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
1378+
if (HopRewriteUtils.isEqualSize(left, right)) {
1379+
//create new subdag sum(A) bop sum(B) for equal-sized matrices
1380+
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1381+
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1382+
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
1383+
//rewire new subdag
1384+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1385+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1386+
1387+
hi = newBin;
1388+
1389+
LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
1390+
}
1391+
// Check if right operand is a vector (has dimension of 1 in either rows or columns)
1392+
else if (right.getDim1() == 1 || right.getDim2() == 1) {
1393+
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1394+
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1395+
1396+
// Row vector case (1 x n)
1397+
if (right.getDim1() == 1) {
1398+
// Create nrow(A) operation using dimensions
1399+
UnaryOp nRows = HopRewriteUtils.createUnary(left, OpOp1.NROW);
1400+
BinaryOp scaledSum = HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT);
1401+
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
1402+
//rewire new subdag
1403+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1404+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1405+
1406+
hi = newBin;
1407+
1408+
LOG.debug("Applied pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+").");
1409+
}
1410+
// Column vector case (n x 1)
1411+
else if (right.getDim2() == 1) {
1412+
// Create ncol(A) operation using dimensions
1413+
UnaryOp nCols = HopRewriteUtils.createUnary(left, OpOp1.NCOL);
1414+
BinaryOp scaledSum = HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT);
1415+
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
1416+
//rewire new subdag
1417+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1418+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1419+
1420+
hi = newBin;
1421+
1422+
LOG.debug("Applied pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+").");
1423+
}
1424+
}
13931425
}
13941426
}
13951427
}
1396-
1428+
13971429
return hi;
13981430
}
13991431

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
197197
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
198198
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
199199

200-
hi = fixNonScalarPrint(hop, hi, i); //e.g., print(m) -> print(toString(m))
201200

202201
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
203202
if( !descendFirst )
@@ -2131,20 +2130,6 @@ else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.EQUAL)) {
21312130
return hi;
21322131
}
21332132

2134-
private static Hop fixNonScalarPrint(Hop parent, Hop hi, int pos) {
2135-
if(HopRewriteUtils.isUnary(parent, OpOp1.PRINT) && !hi.getDataType().isScalar()) {
2136-
LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
2137-
args.put("target", hi);
2138-
Hop newHop = HopRewriteUtils.createParameterizedBuiltinOp(
2139-
hi, args, ParamBuiltinOp.TOSTRING);
2140-
HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
2141-
hi = newHop;
2142-
LOG.debug("Applied fixNonScalarPrint (line " + hi.getBeginLine() + ")");
2143-
}
2144-
2145-
return hi;
2146-
}
2147-
21482133
/**
21492134
* NOTE: currently disabled since this rewrite is INVALID in the
21502135
* presence of NaNs (because (NaN!=NaN) is true).

0 commit comments

Comments
 (0)