Skip to content

Commit d78165b

Browse files
KilianBatmboehm7
authored and committed
[SYSTEMDS-3863] New robust scaling built-in function
Closes #2278.
1 parent 075392e commit d78165b

File tree

8 files changed

+298
-1
lines changed

8 files changed

+298
-1
lines changed

scripts/builtin/scaleRobust.dml

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Robust scaling using median and IQR (Interquartile Range)
23+
# Resistant to outliers by centering with the median and scaling with IQR.
24+
#
25+
# INPUT:
26+
# -------------------------------------------------------------------------------------
27+
# X Input feature matrix of shape n-by-m
28+
# -------------------------------------------------------------------------------------
29+
#
30+
# OUTPUT:
31+
# -------------------------------------------------------------------------------------
32+
# Y Scaled output matrix of shape n-by-m
33+
# med Column medians (Q2) of shape 1-by-m
34+
# q1 Column first quartiles (Q1) of shape 1-by-m
35+
# q3 Column third quartiles (Q3) of shape 1-by-m
36+
# -------------------------------------------------------------------------------------
37+
38+
m_scaleRobust = function(Matrix[Double] X)
39+
return (Matrix[Double] Y, Matrix[Double] med, Matrix[Double] q1, Matrix[Double] q3)
40+
{
42+
# Number of feature columns; the row count is not needed by this function
m = ncol(X)
43+
44+
# Pre-allocate the 1-by-m outputs that the loop below fills one column at a time
med = matrix(0.0, rows=1, cols=m)
45+
q1 = matrix(0.0, rows=1, cols=m)
46+
q3 = matrix(0.0, rows=1, cols=m)
47+
48+
# Define quantile probabilities once, outside the loop
49+
q_probs = as.matrix(list(0.25, 0.5, 0.75));
50+
51+
# Loop over columns to compute quantiles
52+
parfor (j in 1:m) {
53+
# q holds the 0.25, 0.5 and 0.75 quantiles of column j, in that row order
q = quantile(X[,j], q_probs)
54+
med[1,j] = q[2,1]
55+
q1[1,j] = q[1,1]
56+
q3[1,j] = q[3,1]
57+
}
58+
59+
# Center and scale X with the statistics computed above
Y = scaleRobustApply(X, med, q1, q3);
60+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Apply robust scaling using precomputed medians and quartiles (Q1, Q3)
23+
#
24+
# INPUT:
25+
# -------------------------------------------------------------------------------------
26+
# X Input feature matrix of shape n-by-m
27+
# med Column medians (Q2) of shape 1-by-m
28+
# q1 Column first quartiles (Q1) of shape 1-by-m
29+
# q3 Column third quartiles (Q3) of shape 1-by-m
30+
# -------------------------------------------------------------------------------------
31+
#
32+
# OUTPUT:
33+
# -------------------------------------------------------------------------------------
34+
# Y Scaled output matrix of shape n-by-m
35+
# -------------------------------------------------------------------------------------
36+
37+
m_scaleRobustApply = function(Matrix[Double] X, Matrix[Double] med, Matrix[Double] q1, Matrix[Double] q3)
38+
return (Matrix[Double] Y)
39+
{
40+
# Interquartile range, computed element-wise as Q3 - Q1
iqr = q3 - q1
41+
42+
# Ensure robust scaling is safe by replacing invalid IQRs
43+
# A zero IQR would make the division below divide by zero; use a divisor of 1 instead
iqr = replace(target=iqr, pattern=0, replacement=1)
44+
# NaN IQRs are likewise replaced by a divisor of 1
iqr = replace(target=iqr, pattern=NaN, replacement=1)
45+
46+
# Apply robust transformation
47+
# Subtract med from X, then divide by the sanitized iqr
# (med, q1, q3 are documented in the header as 1-by-m row vectors)
Y = (X - med) / iqr
48+
}

src/main/java/org/apache/sysds/api/DMLScript.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ public static boolean executeScript( String[] args )
327327
Map<String, String> argVals = dmlOptions.argVals;
328328

329329
DML_FILE_PATH_ANTLR_PARSER = dmlOptions.filePath;
330-
330+
331331
//Step 3: invoke dml script
332332
printInvocationInfo(fileOrScript, fnameOptConfig, argVals);
333333
execute(dmlScriptStr, fnameOptConfig, argVals, args);

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,8 @@ public enum Builtins {
392392
RMEMPTY("removeEmpty", false, true),
393393
SCALE("scale", true, false),
394394
SCALEAPPLY("scaleApply", true, false),
395+
SCALEROBUST("scaleRobust", true, false),
396+
SCALEROBUSTAPPLY("scaleRobustApply", true, false),
395397
SCALE_MINMAX("scaleMinMax", true, false),
396398
TIME("time", false),
397399
TOKENIZE("tokenize", false, true),
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part2;
21+
22+
import java.util.HashMap;
23+
24+
import org.junit.Test;
25+
26+
import org.apache.sysds.common.Types.ExecMode;
27+
import org.apache.sysds.common.Types.ExecType;
28+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
29+
import org.apache.sysds.test.AutomatedTestBase;
30+
import org.apache.sysds.test.TestConfiguration;
31+
import org.apache.sysds.test.TestUtils;
32+
33+
/**
 * Runs the {@code scaleRobust} DML test script and the matching R script on the
 * same randomly generated input matrix and compares their outputs cell by cell.
 */
public class BuiltinScaleRobustTest extends AutomatedTestBase {
34+
private final static String TEST_NAME = "scaleRobust";
35+
private final static String TEST_DIR = "functions/builtin/";
36+
private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinScaleRobustTest.class.getSimpleName() + "/";
37+
// Tolerance passed to the DML-vs-R matrix comparison
private final static double eps = 1e-10;
38+
// Dimensions of the generated input matrix
private final static int rows = 70;
39+
private final static int cols = 50;
40+
41+
42+
@Override
43+
public void setUp() {
44+
// Register the test configuration with "B" as the single expected output
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
45+
}
46+
47+
@Test
48+
public void testScaleRobustDenseCP() {
49+
runTest(false, ExecType.CP);
50+
}
51+
52+
/**
 * Generates the input, runs the DML and R scripts, and compares output "B".
 *
 * @param sparse if true the input is generated with sparsity 0.1, otherwise 0.9
 * @param et     execution type applied for the duration of the test
 */
private void runTest(boolean sparse, ExecType et) {
53+
ExecMode old = setExecMode(et);
54+
try {
55+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
56+
double sparsity = sparse ? 0.1 : 0.9;
57+
String HOME = SCRIPT_DIR + TEST_DIR;
58+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
59+
fullRScriptName = HOME + TEST_NAME + ".R";
61+
// Single assignment of the program arguments; an earlier assignment that was
// immediately overwritten by this one has been dropped
programArgs = new String[]{"-exec", "singlenode", "-args", input("A"), output("B")};
62+
rCmd = "Rscript " + fullRScriptName + " " + inputDir() + " " + expectedDir();
63+
64+
// Fixed seed (7) so the generated input is the same on every run
double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity, 7);
65+
writeInputMatrixWithMTD("A", A, true);
66+
67+
// Run DML
68+
runTest(true, false, null, -1);
69+
70+
// Run R
71+
runRScript(true);
72+
73+
// Read matrices and compare
74+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");
75+
HashMap<CellIndex, Double> rfile = readRMatrixFromExpectedDir("B");
76+
TestUtils.compareMatrices(dmlfile, rfile, eps, "DML", "R");
77+
} catch (Exception e) {
78+
throw new RuntimeException(e);
79+
} finally {
80+
// Always restore the execution mode that was active before the test
resetExecMode(old);
81+
}
82+
}
83+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
library("Matrix")
23+
24+
args <- commandArgs(TRUE)
25+
options(digits=22)
26+
27+
28+
# Read the input matrix A.mtx from the directory prefix given as the first argument
X = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
29+
colnames(X) = colnames(X, do.NULL=FALSE, prefix="C")
30+
# Y starts as a copy of X and is overwritten column by column below
Y = X
31+
32+
# Scale each column independently: center by its median, divide by its interquartile range
for (j in 1:ncol(X)) {
33+
col = X[, j]
34+
# All three quantiles use type=1 (see R's quantile documentation for the algorithm types)
med = quantile(col, probs=0.5, type=1, names=FALSE, na.rm=FALSE)
35+
q1 = quantile(col, probs=0.25, type=1, names=FALSE, na.rm=FALSE)
36+
q3 = quantile(col, probs=0.75, type=1, names=FALSE, na.rm=FALSE)
37+
iqr = q3 - q1
38+
# Fall back to a divisor of 1 when the IQR is zero or NaN, so the column is only centered
if (iqr == 0 || is.nan(iqr)) iqr = 1
39+
Y[, j] = (col - med) / iqr
40+
}
41+
42+
# Write the scaled matrix as "B" under the prefix given as the second argument
writeMM(as(Y, "CsparseMatrix"), paste(args[2], "B", sep=""))
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read the input matrix from the first script argument
X = read($1);
23+
# scaleRobust returns (Y, med, q1, q3); bind all four so each name matches the value it holds
[Y, med, q1, q3] = scaleRobust(X);
24+
# Only the scaled matrix is written out
write(Y, $2);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
import sys
23+
import numpy as np
24+
from scipy.io import mmread, mmwrite
25+
from scipy.sparse import csc_matrix
26+
from sklearn.preprocessing import RobustScaler
27+
28+
# Script entry point: read A.mtx, scale it with scikit-learn's RobustScaler, write the result as B
if __name__ == "__main__":
29+
# argv[1] and argv[2] are used as path prefixes for the input and output files
input_path = sys.argv[1] + "A.mtx"
30+
output_path = sys.argv[2] + "B"
31+
32+
# Load the Matrix Market file and convert it to a dense array
X = mmread(input_path).toarray()
33+
34+
# Apply RobustScaler
35+
# NOTE(review): default RobustScaler settings are used; confirm that its quantile/IQR
# computation matches the DML and R implementations before relying on exact equality
scaler = RobustScaler()
36+
Y = scaler.fit_transform(X)
37+
38+
# Write the scaled matrix in Matrix Market format
mmwrite(output_path, csc_matrix(Y))

0 commit comments

Comments
 (0)