Skip to content

Commit 7ebf913

Browse files
committed
[SYSTEMDS-3868] Fix missing function hoisting from if predicates
This patch adds the missing hoisting of DML function calls (which always need to bind to variables) from basic if predicates for convenience and in order to prevent unexpected errors. Furthermore, this patch simplifies the existing DML-bodied ampute() builtin by using this features as well as call the existing sigmoid() instead of a custom one.
1 parent 3ce16d0 commit 7ebf913

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

scripts/builtin/ampute.dml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
184184
u_handleDefaults = function(Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights, String mech, Integer numFeatures)
185185
return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
186186
# Patterns: Default is a quadratic matrix wherein pattern i amputes feature i.
187-
empty = u_isEmpty(patterns)
188-
if (empty) { # FIX ME
187+
if (u_isEmpty(patterns)) {
189188
patterns = matrix(1, rows=numFeatures, cols=numFeatures) - diag(matrix(1, rows=numFeatures, cols=1))
190189
}
191190

@@ -205,8 +204,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
205204
}
206205

207206
# Frequencies: Uniform by default.
208-
empty = u_isEmpty(freq) # FIX ME
209-
if (empty) {
207+
if (u_isEmpty(freq)) {
210208
freq = matrix(1 / numPatterns, rows=numPatterns, cols=1)
211209
}
212210
}
@@ -282,7 +280,7 @@ return (Matrix[Double] probsArray) {
282280
while (counter < maxIter & (is.na(currentProb) | abs(currentProb - prop) >= epsilon)) {
283281
counter += 1
284282
shift = lowerRange + (upperRange - lowerRange) / 2
285-
probsArray = u_sigmoid(zScores + shift) # Calculates Right-Sigmoid probability (R implementation's default).
283+
probsArray = sigmoid(zScores + shift) # Calculates Right-Sigmoid probability (R implementation's default).
286284
currentProb = mean(probsArray)
287285
if (currentProb - prop > 0) {
288286
upperRange = shift
@@ -293,11 +291,6 @@ return (Matrix[Double] probsArray) {
293291
}
294292
}
295293

296-
u_sigmoid = function(Matrix[Double] X)
297-
return (Matrix[Double] sigmoided) {
298-
sigmoided = 1 / (1 + exp(-X))
299-
}
300-
301294
u_getBounds = function(Matrix[Double] numPerGroup, Integer groupSize, Integer patternNum)
302295
return(Integer start, Integer end) {
303296
if (patternNum == 1) {

src/main/java/org/apache/sysds/parser/StatementBlock.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,12 @@ else if (current instanceof WhileStatementBlock) {
503503
else if (current instanceof IfStatementBlock) {
504504
IfStatementBlock isb = (IfStatementBlock) current;
505505
IfStatement istmt = (IfStatement)isb.getStatement(0);
506-
//TODO handle predicates
506+
//handle predicate
507+
ArrayList<Statement> tmpPred = new ArrayList<>();
508+
istmt.getConditionalPredicate().setPredicate(
509+
rHoistFunctionCallsFromExpressions(
510+
istmt.getConditionalPredicate().getPredicate(), false, tmpPred, prog));
511+
//handle if and else body
507512
ArrayList<StatementBlock> tmp = new ArrayList<>();
508513
for (StatementBlock sb : istmt.getIfBody())
509514
tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog));
@@ -514,6 +519,8 @@ else if (current instanceof IfStatementBlock) {
514519
tmp2.addAll(rHoistFunctionCallsFromExpressions(sb, prog));
515520
istmt.setElseBody(tmp2);
516521
}
522+
if( !tmpPred.isEmpty() )
523+
return createStatementBlocks(current, tmpPred);
517524
}
518525
else if (current instanceof ForStatementBlock) { //incl parfor
519526
ForStatementBlock fsb = (ForStatementBlock) current;

0 commit comments

Comments
 (0)