Skip to content

Commit ef37bcb

Browse files
authored
Add BuiltinGloVeTest.java
- This test first runs the DML script to generate the top K most similar words for each word in the GloVe word embedding. - Then, it computes the accuracy of the DML results based on the hits of the most similar words for the entire vocabulary, comparing the expected results with the DML output. - To validate the correctness of our GloVe word embedding implementation, we employ a Controlled Overfitting Validation approach. - This methodology addresses the inherent challenge of testing stochastic algorithms, where random initialization typically prevents direct output comparison between different runs or implementations.
1 parent 179659c commit ef37bcb

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part1;
21+
22+
import java.io.IOException;
23+
import java.util.Objects;
24+
25+
import org.apache.sysds.common.Types;
26+
import org.apache.sysds.common.Types.FileFormat;
27+
import org.apache.sysds.runtime.frame.data.FrameBlock;
28+
import org.apache.sysds.test.AutomatedTestBase;
29+
import org.apache.sysds.test.TestConfiguration;
30+
import org.junit.Test;
31+
32+
public class BuiltinGloVeTest extends AutomatedTestBase {
33+
34+
private static final String TEST_NAME = "glove";
35+
private static final String TEST_DIR = "functions/builtin/";
36+
private static final String RESOURCE_DIRECTORY = "./src/test/resources/datasets/";
37+
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGloVeTest.class.getSimpleName() + "/";
38+
39+
private final static Types.ValueType[] schema = {Types.ValueType.STRING};
40+
41+
private static final int TOP_K = 5;
42+
private static final double ACCURACY_THRESHOLD = 0.85;
43+
44+
private static final double seed = 45;
45+
private static final double vector_size = 100;
46+
private static final double alpha = 0.75;
47+
private static final double eta = 0.05;
48+
private static final double x_max = 100;
49+
private static final double tol = 1e-4;
50+
private static final double iterations = 10000;
51+
private static final double print_loss_it = 100;
52+
private static final double maxTokens = 2600;
53+
private static final double windowSize = 15;
54+
private static final String distanceWeighting = "TRUE";
55+
private static final String symmetric = "TRUE";
56+
57+
@Override
58+
public void setUp() {
59+
addTestConfiguration(TEST_NAME,
60+
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"out_result"}));
61+
}
62+
63+
@Test
64+
public void gloveTest() throws IOException{
65+
// Using top-5 words for similarity comparison
66+
runGloVe(TOP_K);
67+
68+
// Read the computed similarity results from SystemDS
69+
FrameBlock computedSimilarity = readDMLFrameFromHDFS("out_result", FileFormat.CSV);
70+
71+
// Load expected results (precomputed in Python)
72+
FrameBlock expectedSimilarity = readDMLFrameFromHDFS(RESOURCE_DIRECTORY + "/GloVe/gloveExpectedTop10.csv", FileFormat.CSV, false);
73+
74+
// Compute accuracy by comparing computed and expected results
75+
double accuracy = computeAccuracy(computedSimilarity, expectedSimilarity, TOP_K);
76+
77+
System.out.println("Computed Accuracy: " + accuracy);
78+
79+
// Ensure accuracy is above a reasonable threshold
80+
assert accuracy > ACCURACY_THRESHOLD : "Accuracy too low! Expected > 85% match.";
81+
}
82+
83+
public void runGloVe(int topK) {
84+
// Load test configuration
85+
Types.ExecMode platformOld = setExecMode(Types.ExecType.CP);
86+
try {
87+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
88+
89+
String HOME = SCRIPT_DIR + TEST_DIR;
90+
91+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
92+
93+
programArgs = new String[] {
94+
"-nvargs",
95+
"input=" + RESOURCE_DIRECTORY + "20news/20news_subset_untokenized.csv",
96+
"seed=" + seed,
97+
"vector_size=" + vector_size,
98+
"alpha=" + alpha,
99+
"eta=" + eta,
100+
"x_max=" + x_max,
101+
"tol=" + tol,
102+
"iterations=" + iterations,
103+
"print_loss_it=" + print_loss_it,
104+
"maxTokens=" + maxTokens,
105+
"windowSize=" + windowSize,
106+
"distanceWeighting=" + distanceWeighting,
107+
"symmetric=" + symmetric,
108+
"topK=" + topK,
109+
"out_result=" + output("out_result")
110+
};
111+
112+
System.out.println("Running DML script...");
113+
runTest(true, false, null, -1);
114+
System.out.println("Test execution completed.");
115+
} finally {
116+
rtplatform = platformOld;
117+
}
118+
}
119+
120+
/**
121+
* Computes accuracy by comparing top-K similar words between computed and expected results.
122+
*/
123+
private double computeAccuracy(FrameBlock computed, FrameBlock expected, int k) {
124+
int count = 0;
125+
for (int i = 0; i < computed.getNumRows(); i++) {
126+
int matchCount = 0;
127+
for (int j = 1; j < k; j++) {
128+
String word1 = computed.getString(i, j);
129+
for (int m = 0; m < k; m++) {
130+
if (Objects.equals(word1, expected.getString(i, m))) {
131+
matchCount++;
132+
break;
133+
}
134+
}
135+
}
136+
if (matchCount > 0) {
137+
count++;
138+
}
139+
}
140+
return (double) count / computed.getNumRows();
141+
}
142+
}

0 commit comments

Comments
 (0)