Skip to content

Commit 6d4eddf

Browse files
ReneEnjilian and mboehm7
authored and committed
[SYSTEMDS-3777] New adasyn builtin function for TPCx-AI
Closes #2133.
1 parent 7a9ecff commit 6d4eddf

File tree

6 files changed

+948
-0
lines changed

6 files changed

+948
-0
lines changed

scripts/builtin/adasyn.dml

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Builtin function for handling class imbalance using Adaptive Synthetic Sampling (ADASYN)
23+
# by Haibo He et al. In International Joint Conference on Neural Networks (2008). 1322-1328
24+
#
25+
# INPUT:
26+
# --------------------------------------------------------------------------------------
27+
# minority Matrix of minority class samples
28+
# majority Matrix of majority class samples
29+
# k Number of nearest neighbors
30+
# beta Desired balance level after generation of synthetic data [0, 1]
31+
# --------------------------------------------------------------------------------------
32+
#
33+
# OUTPUT:
34+
# -------------------------------------------------------------------------------------
35+
# Z Matrix of G synthetic minority class samples, with G = (ml-ms)*beta
36+
# -------------------------------------------------------------------------------------
37+
38+
m_adasyn = function(Matrix[Double] minority, Matrix[Double] majority, Integer k = 1, Double beta = 0.8)
  return (Matrix[Double] Z)
{
  # Generates synthetic minority-class samples by interpolating between a
  # minority sample and one of its minority-class nearest neighbors.
  #
  # minority  rows are minority class samples
  # majority  rows are majority class samples (same number of columns)
  # k         number of nearest neighbors; values < 1 are reset to 1
  # beta      desired balance level, scales the target count G = (ml-ms)*beta
  #
  # Z         synthetic samples, one per row. The row count is the sum of the
  #           per-sample counts g_i = round(r_i * G) over all minority samples
  #           that have at least one minority-class neighbor; because of the
  #           rounding and of skipped samples it can differ slightly from G.
  #
  # Stops if the classes are already balanced (ms/ml >= 0.9), if no minority
  # sample has a majority-class neighbor, or if no sample can be generated.

  if(k < 1) {
    print("ADASYN: k should not be less than 1. Setting k value to default k = 1.")
    k = 1
  }

  # Preprocessing
  dth = 0.9
  ms = nrow(minority)
  ml = nrow(majority)
  # minority rows come first, so row indexes 1..ms of 'combined' are minority
  # samples and indexes > ms are majority samples
  combined = rbind(minority, majority)

  # (Step 1)
  # Calculate the degree of class imbalance
  d = ms/ml

  # (Step 2)
  # Check if imbalance is lower than predefined threshold
  if(d >= dth){
    stop("ADASYN: Class imbalance not large enough.")
  }

  # (Step 2a)
  # Calculate number of synthetic data examples
  G = (ml-ms)*beta

  # (Step 2b)
  # For each x_i in minority class, find k nearest neighbors.
  # k+1 neighbors are requested and the first column is dropped
  # (presumably the nearest neighbor of x_i in 'combined' is x_i itself).
  # Then, compute ratio r of neighbors belonging to majority class to total number of neighbors k
  NNR = knnbf(combined, minority, k+1)
  NNR = NNR[,2:ncol(NNR)]
  delta = rowSums(NNR>ms)
  r = delta/k
  r = r + 0 #only to force materialization, caught by compiler rewrites

  # (Step 2c)
  # Normalize ratio vector r; without any majority neighbor the normalization
  # would divide by zero and produce NaN counts, so fail with a clear message
  rSum = sum(r)
  if(rSum == 0) {
    stop("ADASYN: No minority sample has a majority class neighbor, cannot compute sampling ratios.")
  }
  r = r/rSum

  # (Step 2d)
  # Calculate the number of synthetic data examples that need to be
  # generated for each minority example x_i
  # Then, pre-allocate the result matrix Z
  g = round(r * G)
  gSum = sum(g)
  if(gSum < 1) {
    stop("ADASYN: No synthetic samples to generate for the given inputs.")
  }
  Z = matrix(0, rows=gSum, cols=ncol(minority)) # output matrix, upper bound on rows

  # (Step 2e)
  # For each minority class data example x_i, generate g_i synthetic data examples by
  # looping from 1 to g_i and randomly choosing one minority data example x_j from
  # the k-nearest neighbors. Then, compute the synthetic sample s_i as
  # s_i = x_i + (x_j - x_i) * lambda, with lambda being a random number in [0, 1].
  minNNR = NNR * (NNR <= ms) # set every index from majority class to zero
  nGen = 0 # number of rows of Z written so far (running insert position)
  for(i in 1:ms) {
    gi = as.scalar(g[i, 1])
    if(gi > 0) {
      # keep only the minority neighbor indexes of x_i; if there are none,
      # the first cell of the result is 0 and x_i is skipped
      minRow = removeEmpty(target=minNNR[i, ], margin="cols")
      if(as.scalar(minRow[1, 1]) > 0) {
        xi = minority[i, ]
        for(j in 1:gi) {
          randomIndex = as.scalar(sample(ncol(minRow), 1))
          lambda = as.scalar(rand(rows=1, cols=1, min=0, max=1))
          randomMinIndex = as.scalar(minRow[ , randomIndex])
          randomMinNN = minority[randomMinIndex, ]
          # append after all previously generated samples, so samples of
          # different minority rows never overwrite each other
          nGen = nGen + 1
          Z[nGen, ] = xi + (randomMinNN - xi) * lambda
        }
      }
    }
  }

  # Drop the unused, still all-zero tail of the pre-allocated matrix
  if(nGen < 1) {
    stop("ADASYN: No synthetic samples could be generated, no minority sample has a minority class neighbor.")
  }
  Z = Z[1:nGen, ]
}
118+

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public enum Builtins {
4141
ABSTAIN("abstain", true),
4242
ABS("abs", false),
4343
ACOS("acos", false),
44+
ADASYN("adasyn", true),
4445
ALS("als", true),
4546
ALS_CG("alsCG", true),
4647
ALS_DS("alsDS", true),
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part1;
21+
22+
// Test class for the adasyn builtin function. It currently declares no
// fields or test methods, so it does not exercise the script at all.
// TODO(review): add test cases that run adasyn and check the shape of the result.
public class BuiltinAdasynTest {
23+
}

0 commit comments

Comments
 (0)