Skip to content

Commit 066c0aa

Browse files
committed
[SYSTEMDS-3777] Additional adasyn real data tests
This generalizes the adasyn test to additional real datasets. On the Titanic dataset, adasyn improves test accuracy by 1.6 percentage points (for a basic logreg model, 0.781 -> 0.797).
1 parent 8edabbc commit 066c0aa

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1610,7 +1610,7 @@ else if (source instanceof ExpressionList){
16101610
return currBuiltinOp;
16111611
}
16121612
else{
1613-
throw new ParseException("Unhandled instance of source type: " + source.getClass());
1613+
throw new ParseException("Unhandled instance of source type: " + source);
16141614
}
16151615
}
16161616
catch(ParseException e ){

src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ public class BuiltinAdasynRealDataTest extends AutomatedTestBase {
3535

3636
private final static String DIABETES_DATA = DATASET_DIR + "diabetes/diabetes.csv";
3737
private final static String DIABETES_TFSPEC = DATASET_DIR + "diabetes/tfspec.json";
38+
private final static String TITANIC_DATA = DATASET_DIR + "titanic/titanic.csv";
39+
private final static String TITANIC_TFSPEC = DATASET_DIR + "titanic/tfspec.json";
3840

3941
@Override
4042
public void setUp() {
@@ -56,15 +58,30 @@ public void testDiabetesAdasynK6() {
5658
runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 6, ExecType.CP);
5759
}
5860

61+
@Test
62+
public void testTitanicNoAdasyn() {
63+
runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, false, 0.781, -1, ExecType.CP);
64+
}
65+
66+
@Test
67+
public void testTitanicAdasynK4() {
68+
runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, true, 0.797, 4, ExecType.CP);
69+
}
70+
71+
@Test
72+
public void testTitanicAdasynK5() {
73+
runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, true, 0.797, 5, ExecType.CP);
74+
}
75+
5976
private void runAdasynTest(String data, String tfspec, boolean adasyn, double minAcc, int k, ExecType instType) {
6077
Types.ExecMode platformOld = setExecMode(instType);
6178
try {
6279
loadTestConfiguration(getTestConfiguration(TEST_NAME));
6380

6481
String HOME = SCRIPT_DIR + TEST_DIR;
6582
fullDMLScriptName = HOME + TEST_NAME + ".dml";
66-
programArgs = new String[] {"-stats",
67-
"-args", data, String.valueOf(adasyn), String.valueOf(k), output("R")};
83+
programArgs = new String[] {"-stats", "-args",
84+
data, tfspec, String.valueOf(adasyn), String.valueOf(k), output("R")};
6885

6986
runTest(true, false, null, -1);
7087

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/test/scripts/functions/builtin/adasynRealData.dml

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,36 @@
2020
#-------------------------------------------------------------
2121

2222

23-
M = read($1, data_type="matrix", format="csv", header=TRUE);
24-
Y = M[, ncol(M)] + 1
25-
X = M[, 1:ncol(M)-1]
26-
upsample = as.logical($2)
23+
M = read($1, data_type="frame", format="csv", header=TRUE,
24+
naStrings= ["NA", "null"," ","NaN", "nan", "", " ", "_nan_", "inf", "?", "NAN", "99999", "99999.00"]);
25+
Y = as.matrix(M[, ncol(M)]) + 1
26+
F = M[, 1:ncol(M)-1]
27+
tfspec = read($2, data_type="scalar", value_type="string")
28+
upsample = as.logical($3)
29+
30+
if( tfspec != " " ) {
31+
F = M[, 1:ncol(M)] # FIXME
32+
[X,meta] = transformencode(target=F, spec=tfspec);
33+
X = X[,1:ncol(X)-1];
34+
X = imputeByMode(X);
35+
}
36+
else {
37+
X = as.matrix(F);
38+
}
39+
40+
[X,C,S] = scale(X=X, scale=TRUE, center=TRUE);
2741

2842
[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7, seed=3);
2943

3044
if( upsample ) {
3145
# oversampling all classes other than majority
32-
[Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3, seed=7);
46+
[Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$4, seed=7);
3347
}
3448

3549
B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);
3650
[P,yhat,acc] = multiLogRegPredict(X=Xtest, Y=Ytest, B=B);
3751
print("accuracy: "+acc)
3852

3953
R = as.matrix(acc/100);
40-
write(R, $4);
54+
write(R, $5);
4155

0 commit comments

Comments
 (0)