Skip to content

Commit 159cc8e

Browse files
committed
[SYSTEMDS-3777] Improved adasyn builtin (tests, vectorized impl)
This patch adds real-data tests for the new adasyn builtin function, and changes the implementation to a vectorized implementation that extracts over-sampled rows via a randomized permutation matrix multiply. On the Diabetes dataset (with moderate class imbalance of 500 vs 268) ADASYN slightly improves the test accuracy from 78.3 to 78.7%. It is also noteworthy that the original ADASYN paper from 2008 only achieved 0.6831 and 0.6833 (with ADASYN) on this dataset.
1 parent 6d4eddf commit 159cc8e

File tree

5 files changed

+130
-105
lines changed

5 files changed

+130
-105
lines changed

scripts/builtin/adasyn.dml

Lines changed: 29 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,95 +24,60 @@
2424
#
2525
# INPUT:
2626
# --------------------------------------------------------------------------------------
27-
# minority Matrix of minority class samples
28-
# majority Matrix of majority class samples
29-
# k Number of nearest neighbors
30-
# beta Desired balance level after generation of synthetic data [0, 1]
27+
# X Feature matrix [shape: n-by-m]
28+
# Y Class labels [shape: n-by-1]
29+
# k Number of nearest neighbors
30+
# beta Desired balance level after generation of synthetic data [0, 1]
31+
# dth Distribution threshold
3132
# --------------------------------------------------------------------------------------
3233
#
3334
# OUTPUT:
3435
# -------------------------------------------------------------------------------------
35-
# Z Matrix of G synthetic minority class samples, with G = (ml-ms)*beta
36+
# Xp Feature matrix of n original rows followed by G = (ml-ms)*beta synthetic rows
37+
# Yp Class labels aligned with output X
3638
# -------------------------------------------------------------------------------------
3739

38-
m_adasyn = function(Matrix[Double] minority, Matrix[Double] majority, Integer k = 1, Double beta = 0.8)
39-
return (Matrix[Double] Z)
40+
m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
41+
Double beta = 1.0, Double dth = 0.9)
42+
return (Matrix[Double] Xp, Matrix[Double] Yp)
4043
{
4144
if(k < 1) {
4245
print("ADASYN: k should not be less than 1. Setting k value to default k = 1.")
4346
k = 1
4447
}
4548

4649
# Preprocessing
47-
dth = 0.9
48-
ms = nrow(minority)
49-
ml = nrow(majority)
50-
combined = rbind(minority, majority)
50+
freq = t(table(Y, 1));
51+
minorIdx = as.scalar(rowIndexMin(freq))
52+
majorIdx = as.scalar(rowIndexMax(freq))
5153

5254
# (Step 1)
5355
# Calculate the degree of class imbalance, where d in (0, 1]
54-
d = ms/ml
56+
d = as.scalar(freq[1,minorIdx])/sum(freq)
5557

5658
# (Step 2)
5759
# Check if imbalance is lower than predefined threshold
58-
if(d >= dth){
60+
print("ADASYN: class imbalance: " + d)
61+
62+
if(d >= dth) {
5963
stop("ADASYN: Class imbalance not large enough.")
6064
}
6165

6266
# (Step 2a)
6367
# Calculate number of synthetic data examples
64-
G = (ml-ms)*beta
68+
G = as.scalar(freq[1,majorIdx]-freq[1,minorIdx])*beta
6569

6670
# (Step 2b)
67-
# For each x_i in minority class, find k nearest neighbors.
68-
# Then, compute ratio r of neighbors belonging to majority class to total number of neighbors k
69-
NNR = knnbf(combined, minority, k+1)
70-
NNR = NNR[,2:ncol(NNR)]
71-
delta = rowSums(NNR>ms)
72-
r = delta/k
73-
r = r + 0 #only to force materialization, caught by compiler rewrites
74-
75-
# (Step 2c)
76-
# Normalize ratio vector r
77-
rSum = sum(r)
78-
r = r/rSum
79-
80-
# (Step 2d)
81-
# Calculate the number of synthetic data examples that need to be
82-
# generated for each minority example x_i
83-
# Then, pre-allocate the result matrix Z
84-
g = round(r * G)
85-
gSum = sum(g)
86-
Z = matrix(0, rows=gSum, cols=ncol(minority)) # output matrix, slightly overallocated
87-
88-
# (Step 2e)
89-
# For each minority class data example x_i, generate g_i synthetic data examples by
90-
# looping from 1 to g_i and randomly choosing one minority data example x_j from
91-
# the k-nearest neighbors. Then, compute the synthetic sample s_i as
92-
# s_i = x_i + (x_j - x_i) * lambda, with lambda being a random number in [0, 1].
93-
minNNR = NNR * (NNR <= ms) # set every index from majority class to zero
94-
zeroCount = 0
95-
for(i in 1:nrow(minority)){
96-
row = minNNR[i, ] # slice a row
97-
minRow = removeEmpty(target=row, margin="cols") # remove all zero values from that row
98-
hasSynthetic = as.scalar(g[i])>0
99-
hasMinorityNN = (as.scalar(minRow[1, 1]) > 0) & (hasSynthetic)
100-
if(hasMinorityNN){
101-
for(j in 1:as.scalar(g[i])){
102-
randomIndex = as.scalar(sample(ncol(minRow), 1))
103-
lambda = as.scalar(rand(rows=1, cols=1, min=0, max=1))
104-
randomMinIndex = as.scalar(minRow[ , randomIndex])
105-
randomMinNN = minority[randomMinIndex, ]
106-
insIdx = i+j-1-zeroCount
107-
Z[insIdx, ] = minority[i, ] + (randomMinNN - minority[i, ]) * lambda
108-
}
109-
} else {
110-
zeroCount = zeroCount + 1
111-
}
112-
}
113-
114-
diff = nrow(minority) - gSum
115-
numTrailZeros = zeroCount - diff
116-
Z = Z[1:gSum-numTrailZeros, ]
71+
# For each x_i in non-majority class, find k nearest neighbors.
72+
# Get G random points from the KNN set via a permutation matrix multiply
73+
Xnonmajor = removeEmpty(target=X, margin="rows", select=(Y!=majorIdx))
74+
Ynonmajor = removeEmpty(target=Y, margin="rows", select=(Y!=majorIdx))
75+
NNR = knnbf(Xnonmajor, Xnonmajor, k+1)
76+
NNR = matrix(NNR, rows=length(NNR), cols=1)
77+
I = rand(rows=nrow(NNR), cols=1) < (G/nrow(NNR))
78+
NNRg = removeEmpty(target=NNR, margin="rows", select=I);
79+
P = table(seq(1, nrow(NNRg)), NNRg, nrow(NNRg), nrow(Xnonmajor));
80+
Xp = rbind(X, P %*% Xnonmajor);
81+
Yp = rbind(Y, P %*% Ynonmajor); # multi-class
11782
}
11883

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part1;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.common.Types.ExecType;
24+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
25+
import org.apache.sysds.test.AutomatedTestBase;
26+
import org.apache.sysds.test.TestConfiguration;
27+
import org.apache.sysds.utils.Statistics;
28+
import org.junit.Assert;
29+
import org.junit.Test;
30+
31+
public class BuiltinAdasynRealDataTest extends AutomatedTestBase {
32+
private final static String TEST_NAME = "adasynRealData";
33+
private final static String TEST_DIR = "functions/builtin/";
34+
private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinAdasynRealDataTest.class.getSimpleName() + "/";
35+
36+
private final static String DIABETES_DATA = DATASET_DIR + "diabetes/diabetes.csv";
37+
private final static String DIABETES_TFSPEC = DATASET_DIR + "diabetes/tfspec.json";
38+
39+
@Override
40+
public void setUp() {
41+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
42+
}
43+
44+
@Test
45+
public void testDiabetesNoAdasyn() {
46+
runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, false, 0.783, -1, ExecType.CP);
47+
}
48+
49+
@Test
50+
public void testDiabetesAdasynK4() {
51+
runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 4, ExecType.CP);
52+
}
53+
54+
@Test
55+
public void testDiabetesAdasynK6() {
56+
runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 6, ExecType.CP);
57+
}
58+
59+
private void runAdasynTest(String data, String tfspec, boolean adasyn, double minAcc, int k, ExecType instType) {
60+
Types.ExecMode platformOld = setExecMode(instType);
61+
try {
62+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
63+
64+
String HOME = SCRIPT_DIR + TEST_DIR;
65+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
66+
programArgs = new String[] {"-stats",
67+
"-args", data, String.valueOf(adasyn), String.valueOf(k), output("R")};
68+
69+
runTest(true, false, null, -1);
70+
71+
double acc = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
72+
Assert.assertTrue(acc >= minAcc);
73+
Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
74+
}
75+
finally {
76+
rtplatform = platformOld;
77+
}
78+
}
79+
}

src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/test/resources/datasets/diabetes/diabetes.json

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/test/scripts/functions/builtin/adasyn.dml renamed to src/test/scripts/functions/builtin/adasynRealData.dml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,25 @@
1717
# specific language governing permissions and limitations
1818
# under the License.
1919
#
20-
#--------------------------------------
20+
#-------------------------------------------------------------
21+
22+
23+
M = read($1, data_type="matrix", format="csv", header=TRUE);
24+
Y = M[, ncol(M)] + 1
25+
X = M[, 1:ncol(M)-1]
26+
upsample = as.logical($2)
27+
28+
[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7);
29+
30+
if( upsample ) {
31+
# oversampling all classes other than majority
32+
[Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3);
33+
}
34+
35+
B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);
36+
[P,yhat,acc] = multiLogRegPredict(X=Xtest, Y=Ytest, B=B);
37+
print("accuracy: "+acc)
38+
39+
R = as.matrix(acc/100);
40+
write(R, $4);
41+

0 commit comments

Comments
 (0)