Skip to content

Commit 341cf3f

Browse files
committed
[SYSTEMDS-3868] Fix new hoisting of function calls from if statements
This patch fixes a compilation issue, where certain if-branches were lost during hoisting of function calls from if statements. The issue did not show up, because for test instances of function calls in if statements triggered rewrites of predicate constant folding and branch removal. Now, we also resolved all remaining FIXMEs in the new ampute() builtin function.
1 parent 40f9e54 commit 341cf3f

File tree

4 files changed

+62
-7
lines changed

4 files changed

+62
-7
lines changed

scripts/builtin/ampute.dml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,6 @@ u_validateInputs = function(Matrix[Double] X, Double prop, Matrix[Double] freq,
9191
return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
9292

9393
errors = list()
94-
freqProvided = !u_isEmpty(freq) # FIX ME
95-
patternsProvided = !u_isEmpty(patterns) # FIX ME
96-
weightsProvided = !u_isEmpty(weights) # FIX ME
9794

9895
# About the input dataset:
9996
if (max(is.na(X)) == 1) {
@@ -107,7 +104,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
107104
if (mech != "MAR" & mech != "MCAR" & mech != "MNAR") {
108105
errors = append(errors, "Invalid option provided for mech: " + mech + ".")
109106
}
110-
else if (weightsProvided & mech == "MCAR") {
107+
else if (!u_isEmpty(weights) & mech == "MCAR") {
111108
print("ampute warning: User-provided weights will be ignored when mechanism MCAR is chosen.")
112109
}
113110

@@ -190,11 +187,10 @@ return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
190187

191188
# Weights: Various defaults based on chosen missingness mechanism:
192189
numPatterns = nrow(patterns)
193-
empty = u_isEmpty(weights) # FIX ME
194190
if (mech == "MCAR") {
195191
weights = matrix(0, rows=numPatterns, cols=numFeatures) # MCAR: All 0's (weights don't matter). Overrides any provided weights.
196192
}
197-
else if (empty) { # FIX ME
193+
else if (u_isEmpty(weights)) {
198194
if (mech == "MAR") {
199195
weights = patterns # MAR: Missing features weighted with 0.
200196
}

src/main/java/org/apache/sysds/parser/StatementBlock.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ else if (current instanceof IfStatementBlock) {
520520
istmt.setElseBody(tmp2);
521521
}
522522
if( !tmpPred.isEmpty() )
523-
return createStatementBlocks(current, tmpPred);
523+
return createStatementBlocks(current, tmpPred, true);
524524
}
525525
else if (current instanceof ForStatementBlock) { //incl parfor
526526
ForStatementBlock fsb = (ForStatementBlock) current;
@@ -659,6 +659,12 @@ private static DataIdentifier copy(DataIdentifier di) {
659659
}
660660

661661
private static List<StatementBlock> createStatementBlocks(StatementBlock sb, List<Statement> stmts) {
662+
return createStatementBlocks(sb, stmts, false);
663+
}
664+
665+
private static List<StatementBlock> createStatementBlocks(
666+
StatementBlock sb, List<Statement> stmts, boolean includeSb)
667+
{
662668
List<StatementBlock> ret = new ArrayList<>();
663669
StatementBlock current = new StatementBlock(sb);
664670
for(Statement stmt : stmts) {
@@ -679,6 +685,8 @@ private static List<StatementBlock> createStatementBlocks(StatementBlock sb, Lis
679685
}
680686
if( current.getNumStatements() > 0 )
681687
ret.add(current);
688+
if( includeSb ) // e.g., if block
689+
ret.add(sb);
682690
return ret;
683691
}
684692

src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
6363
"FunPotpourriParforEvalBuiltin",
6464
"FunPotpourriParforEvalSpark",
6565
"FunPotpourriEvalNamespace3",
66+
"FunPotpourriDefaultParams",
6667
};
6768

6869
private final static String TEST_DIR = "functions/misc/";
@@ -277,6 +278,11 @@ public void testFunctionEvalNamespace3() {
277278
runFunctionTest( TEST_NAMES[29], null, false );
278279
}
279280

281+
@Test
282+
public void testFunctionDefaultParams() {
283+
runFunctionTest( TEST_NAMES[30], null, false );
284+
}
285+
280286
private void runFunctionTest(String testName, Class<?> error) {
281287
runFunctionTest(testName, error, false);
282288
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
foo = function(Matrix[Double] X, Matrix[Double] h=matrix(0,0,0))
23+
return (Matrix[Double] X2)
24+
{
25+
print(nrow(h)+" "+ncol(h)+" "+length(h));
26+
#if( length(h)==0 ) # worked
27+
if( isEmpty(h) ) # didn't work
28+
h = 2 * X;
29+
X2 = h - X;
30+
}
31+
32+
isEmpty = function(Matrix[Double] h)
33+
return (boolean R)
34+
{
35+
R = (length(h) == 0);
36+
}
37+
38+
39+
X = rand(rows=100, cols=100, seed=7)
40+
41+
X2 = foo(X=X);
42+
43+
if( sum((X-X2)>1e-10)>0 )
44+
stop("Incorrect results: "+sum(X)+" "+sum(X2))
45+

0 commit comments

Comments
 (0)