Skip to content

Commit dda11b6

Browse files
authored
Add test for cooccurrence Matrix script.
The test checks the result of the cooccurrenceMatrix.dml for a small dataset.
1 parent 52a476a commit dda11b6

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package org.apache.sysds.test.functions.builtin.part1;
2+
3+
import org.apache.sysds.common.Types;
4+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
5+
import org.apache.sysds.test.AutomatedTestBase;
6+
import org.apache.sysds.test.TestConfiguration;
7+
import org.apache.sysds.test.TestUtils;
8+
import org.junit.Test;
9+
10+
import java.util.HashMap;
11+
12+
public class BuiltinCooccurrenceMatrixTest extends AutomatedTestBase {
13+
14+
private static final String TEST_NAME = "cooccurrenceMatrix";
15+
private static final String TEST_DIR = "functions/builtin/";
16+
private static final String RESOURCE_DIRECTORY = "src/test/resources/datasets/";
17+
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinCooccurrenceMatrixTest.class.getSimpleName() + "/";
18+
private static final double EPSILON = 1e-10; // Tolerance for comparison
19+
20+
@Override
21+
public void setUp() {
22+
addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"TestResult",}));
23+
}
24+
25+
@Test
26+
public void cooccurrenceMatrixTest() {
27+
runCooccurrenceMatrix(20, 2, "FALSE", "TRUE");
28+
HashMap<MatrixValue.CellIndex, Double> cooccurrenceMatrix = readDMLMatrixFromOutputDir("TestResult");
29+
double[][] computedC = TestUtils.convertHashMapToDoubleArray(cooccurrenceMatrix);
30+
31+
// Unique words: {apple, banana, orange, grape}
32+
// Co-occurrence based on word pairs in same sentences
33+
double[][] expectedC = new double[][] {
34+
{0, 1, 2, 0}, // apple with {banana, orange}
35+
{1, 0, 3, 1}, // banana with {apple, orange, grape}
36+
{2, 3, 0, 2}, // orange with {apple, banana, grape}
37+
{0, 1, 2, 0} // grape with {banana, orange, grape}
38+
};
39+
40+
TestUtils.compareMatrices(expectedC, computedC, expectedC.length, expectedC[0].length, EPSILON);
41+
42+
}
43+
44+
public void runCooccurrenceMatrix(Integer maxTokens, Integer windowSize, String distanceWeighting, String symmetric) {
45+
// Load test configuration
46+
Types.ExecMode platformOld = setExecMode(Types.ExecType.CP);
47+
try{
48+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
49+
50+
String HOME = SCRIPT_DIR + TEST_DIR;
51+
52+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
53+
54+
programArgs = new String[]{"-nvargs",
55+
"input=" + RESOURCE_DIRECTORY + "GloVe/coocMatrixTest.csv",
56+
"maxTokens=" + maxTokens,
57+
"windowSize=" + windowSize,
58+
"distanceWeighting=" + distanceWeighting,
59+
"symmetric=" + symmetric,
60+
"out_file=" + output("TestResult")};
61+
System.out.println("Run dml script..");
62+
runTest(true, false, null, -1);
63+
System.out.println("DONE");
64+
}
65+
finally {
66+
rtplatform = platformOld;
67+
}
68+
}
69+
70+
71+
}

0 commit comments

Comments
 (0)