Skip to content

Commit 3ae6a77

Browse files
committed
[SYSTEMDS-3898] Fix correctness CP quantile pick instruction
This patch fixes an issue with quantiles for even-length arrays, where the median for example is not a picked value but an average over two values. As it turns out, the quantile kernel already supported averaging but was called incorrectly for quantile() but correctly for median(). Now we have equivalent results to R and consistency in terms of quantile(X, 0.5) == median(X). The distributed Spark operations and weighted kernels need some additional thought and a more involved implementation. Thanks to Ramon Schoendorf for catching and reporting this issue.
1 parent 8bed176 commit 3ae6a77

File tree

5 files changed

+91
-24
lines changed

5 files changed

+91
-24
lines changed

src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,16 @@ public void processInstruction(ExecutionContext ec) {
9090

9191
if ( input2.getDataType() == DataType.SCALAR ) {
9292
ScalarObject quantile = ec.getScalarInput(input2);
93-
double picked = matBlock.pickValue(quantile.getDoubleValue());
93+
//pick value w/ explicit averaging for even-length arrays
94+
double picked = matBlock.pickValue(
95+
quantile.getDoubleValue(), matBlock.getLength()%2==0);
9496
ec.setScalarOutput(output.getName(), new DoubleObject(picked));
9597
}
9698
else {
9799
MatrixBlock quantiles = ec.getMatrixInput(input2.getName());
98-
MatrixBlock resultBlock = matBlock.pickValues(quantiles, new MatrixBlock());
100+
//pick value w/ explicit averaging for even-length arrays
101+
MatrixBlock resultBlock = matBlock.pickValues(
102+
quantiles, new MatrixBlock(), matBlock.getLength()%2==0);
99103
quantiles = null;
100104
ec.releaseMatrixInput(input2.getName());
101105
ec.setMatrixOutput(output.getName(), resultBlock);

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4755,7 +4755,10 @@ public static double computeIQMCorrection(double sum, double sum_wt,
47554755
}
47564756

47574757
public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) {
4758+
return pickValues(quantiles, ret, false);
4759+
}
47584760

4761+
public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret, boolean average) {
47594762
MatrixBlock qs=checkType(quantiles);
47604763

47614764
if ( qs.clen != 1 ) {

src/test/java/org/apache/sysds/test/functions/binary/matrix/QuantileTest.java

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,19 @@
2222
import java.util.HashMap;
2323

2424
import org.junit.Test;
25-
import org.apache.sysds.api.DMLScript;
2625
import org.apache.sysds.common.Types.ExecMode;
2726
import org.apache.sysds.common.Types.ExecType;
2827
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
2928
import org.apache.sysds.test.AutomatedTestBase;
3029
import org.apache.sysds.test.TestConfiguration;
3130
import org.apache.sysds.test.TestUtils;
3231

33-
/**
34-
*
35-
*/
3632
public class QuantileTest extends AutomatedTestBase
3733
{
38-
3934
private final static String TEST_NAME1 = "Quantile";
4035
private final static String TEST_NAME2 = "Median";
4136
private final static String TEST_NAME3 = "IQM";
37+
private final static String TEST_NAME4 = "QuantileBug";
4238

4339
private final static String TEST_DIR = "functions/binary/matrix/";
4440
private final static String TEST_CLASS_DIR = TEST_DIR + QuantileTest.class.getSimpleName() + "/";
@@ -59,6 +55,8 @@ public void setUp()
5955
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
6056
addTestConfiguration(TEST_NAME3,
6157
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
58+
addTestConfiguration(TEST_NAME4,
59+
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
6260
}
6361

6462
@Test
@@ -161,19 +159,21 @@ public void testIQMSparseSP() {
161159
runQuantileTest(TEST_NAME3, -1, true, ExecType.SPARK);
162160
}
163161

162+
@Test
163+
public void testQuantileBugCP() {
164+
runQuantileTest(TEST_NAME4, 0.5, false, ExecType.CP);
165+
}
166+
167+
// TODO reimplement distributed value pick logic
168+
// @Test
169+
// public void testQuantileBugSP() {
170+
// runQuantileTest(TEST_NAME4, 0.5, false, ExecType.SPARK);
171+
// }
172+
164173
private void runQuantileTest( String TEST_NAME, double p, boolean sparse, ExecType et)
165174
{
166-
//rtplatform for MR
167-
ExecMode platformOld = rtplatform;
168-
switch( et ){
169-
case SPARK: rtplatform = ExecMode.SPARK; break;
170-
default: rtplatform = ExecMode.HYBRID; break;
171-
}
175+
ExecMode platformOld = setExecMode(et);
172176

173-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
174-
if( rtplatform == ExecMode.SPARK )
175-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
176-
177177
try
178178
{
179179
getAndLoadTestConfiguration(TEST_NAME);
@@ -185,9 +185,11 @@ private void runQuantileTest( String TEST_NAME, double p, boolean sparse, ExecTy
185185
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + p + " "+ expectedDir();
186186

187187
//generate actual dataset (always dense because values <=0 invalid)
188-
double sparsitya = sparse ? sparsity2 : sparsity1;
189-
double[][] A = getRandomMatrix(rows, 1, 1, maxVal, sparsitya, 1236);
190-
writeInputMatrixWithMTD("A", A, true);
188+
if( !TEST_NAME.equals(TEST_NAME4) ) {
189+
double sparsitya = sparse ? sparsity2 : sparsity1;
190+
double[][] A = getRandomMatrix(rows, 1, 1, maxVal, sparsitya, 1236);
191+
writeInputMatrixWithMTD("A", A, true);
192+
}
191193

192194
runTest(true, false, null, -1);
193195
runRScript(true);
@@ -198,9 +200,7 @@ private void runQuantileTest( String TEST_NAME, double p, boolean sparse, ExecTy
198200
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
199201
}
200202
finally {
201-
rtplatform = platformOld;
202-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
203+
resetExecMode(platformOld);
203204
}
204205
}
205-
206-
}
206+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args <- commandArgs(TRUE)
23+
options(digits=22)
24+
25+
library("Matrix")
26+
27+
A = as.matrix(c(1,5,7,10))
28+
p = as.double(args[2]);
29+
30+
s = quantile(A, p);
31+
m = as.matrix(s);
32+
33+
writeMM(as(m, "CsparseMatrix"), paste(args[3], "R", sep=""));
34+
35+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
A = as.matrix(list(1,5,7,10));
23+
s = quantile(A, $2);
24+
m = as.matrix(s);
25+
write(m, $3, format="text");

0 commit comments

Comments
 (0)