Skip to content

Commit 8edabbc

Browse files
committed
[SYSTEMDS-3777] Fix adasyn test flakiness via fixed seeds
1 parent 159cc8e commit 8edabbc

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

scripts/builtin/adasyn.dml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
# k Number of nearest neighbors
3030
# beta Desired balance level after generation of synthetic data [0, 1]
3131
# dth Distribution threshold
32+
# seed Seed for randomized data point selection
3233
# --------------------------------------------------------------------------------------
3334
#
3435
# OUTPUT:
@@ -38,7 +39,7 @@
3839
# -------------------------------------------------------------------------------------
3940

4041
m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
41-
Double beta = 1.0, Double dth = 0.9)
42+
Double beta = 1.0, Double dth = 0.9, Integer seed = -1)
4243
return (Matrix[Double] Xp, Matrix[Double] Yp)
4344
{
4445
if(k < 1) {
@@ -74,7 +75,7 @@ m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
7475
Ynonmajor = removeEmpty(target=Y, margin="rows", select=(Y!=majorIdx))
7576
NNR = knnbf(Xnonmajor, Xnonmajor, k+1)
7677
NNR = matrix(NNR, rows=length(NNR), cols=1)
77-
I = rand(rows=nrow(NNR), cols=1) < (G/nrow(NNR))
78+
I = rand(rows=nrow(NNR), cols=1, seed=seed) < (G/nrow(NNR))
7879
NNRg = removeEmpty(target=NNR, margin="rows", select=I);
7980
P = table(seq(1, nrow(NNRg)), NNRg, nrow(NNRg), nrow(Xnonmajor));
8081
Xp = rbind(X, P %*% Xnonmajor);

src/test/scripts/functions/builtin/adasynRealData.dml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,11 @@ Y = M[, ncol(M)] + 1
2525
X = M[, 1:ncol(M)-1]
2626
upsample = as.logical($2)
2727

28-
[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7);
28+
[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7, seed=3);
2929

3030
if( upsample ) {
3131
# oversampling all classes other than majority
32-
[Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3);
32+
[Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3, seed=7);
3333
}
3434

3535
B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);

0 commit comments

Comments
 (0)