Skip to content

Commit b8d373a

Browse files
committed
[SYSTEMDS-3853] Fix ampute outer broadcasting and error handling
This patch fixes an invalid left-hand-side and left- and right-hand-side broadcasting in the new ampute builtin function. We now have a proper error handling in the hop to guide script developers that broadcasts can only be used from the right-hand-side.
1 parent c4e7e46 commit b8d373a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

scripts/builtin/ampute.dml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ m_ampute = function(Matrix[Double] X,
7272

7373
# 4. Use probabilities to ampute pattern candidates:
7474
random = rand(rows=groupSize, cols=1, min=0, max=1, pdf="uniform", seed=seed)
75-
amputeds = (random <= probs) * (1 - patterns[patternNum]) # Obtains matrix with 1's at indices to ampute.
76-
while (FALSE) {} # FIX ME
75+
# Obtains matrix with 1's at indices to ampute.
76+
amputeds = outer((random <= probs), (1 - patterns[patternNum]), "*")
7777
groupSamples = groupSamples + replace(target=amputeds, pattern=1, replacement=NaN)
7878

7979
# 5. Update output matrix:
@@ -241,7 +241,6 @@ return (Matrix[Double] groupAssignments, Matrix[Double] groupCounts) {
241241

242242
for (i in 1:numGroups) {
243243
assigned = (random >= cumSum[i]) & (random < cumSum[i + 1])
244-
while (FALSE) {} # FIX ME
245244
groupCounts[i] = sum(assigned)
246245
groupAssignments = groupAssignments + i * assigned
247246
}
@@ -308,4 +307,4 @@ return(Integer start, Integer end) {
308307
start = sum(numPerGroup[1:(patternNum - 1), ]) + 1
309308
}
310309
end = start + groupSize - 1
311-
}
310+
}

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,15 @@ else if( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX )
10921092
}
10931093
else //GENERAL CASE
10941094
{
1095+
//check correct broadcasting dimensions
1096+
if( (input1.getDim1()==1 && input2.getDim1() > 1)
1097+
|| (input1.getDim2()==1 && input2.getDim2() > 1) )
1098+
{
1099+
throw new HopsException("Invalid binary broadcasting from left: "
1100+
+ input1.getDataCharacteristics()+" "+getOp().name()+" "
1101+
+input2.getDataCharacteristics());
1102+
}
1103+
10951104
ldim1 = (input1.rowsKnown()) ? input1.getDim1()
10961105
: ((input2.getDim1()>1)?input2.getDim1():-1);
10971106
ldim2 = (input1.colsKnown()) ? input1.getDim2()

0 commit comments

Comments
 (0)