Skip to content

Commit 6fd08c0

Browse files
saminbassirimboehm7
authored andcommitted
[SYSTEMDS-3179] Builtin for GloVe cooccurrence matrix computation
Closes #2200.
1 parent 5f390b3 commit 6fd08c0

File tree

5 files changed

+295
-0
lines changed

5 files changed

+295
-0
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
#
22+
# The implementation is based on
23+
# https://github.com/stanfordnlp/GloVe/blob/master/src/cooccur.c
24+
#
25+
#-------------------------------------------------------------
26+
27+
## Cleans and processes text data by removing punctuation, converting it to lowercase, and reformatting.
28+
## Adds an index column to the result.
29+
# INPUT:
30+
# ------------------------------------------------------------------------------
31+
# S (Frame[Unknown]): 1D input data frame containing text data.
32+
# ------------------------------------------------------------------------------
33+
# OUTPUT:
34+
# ------------------------------------------------------------------------------
35+
# result (Frame[Unknown]): Processed text data with an index column.
36+
# ------------------------------------------------------------------------------
37+
processText = function(Frame[Unknown] S) return (Frame[Unknown] result){
  # Trace message marking the start of this pipeline stage.
  print("processText");

  # 1) Remove every literal period from the first column of S.
  withoutDots = map(S[,1], "x -> x.replaceAll(\"[.]\", \"\")");
  # 2) Replace every character that is neither an ASCII letter nor whitespace with a blank.
  lettersOnly = map(withoutDots, "x -> x.replaceAll(\"[^a-zA-Z\\s]\", \" \")");
  # 3) Lower-case the remaining text.
  lowered = map(lettersOnly, "x -> x.toLowerCase()");

  # Prepend a running row id (1..nrow(S)) as the first column of the result.
  rowIds = as.frame(seq(1, nrow(S), 1));
  result = cbind(rowIds, lowered);
}
44+
45+
## Tokenizes text data and retrieves word positions.
46+
# INPUT:
47+
# ------------------------------------------------------------------------------
48+
# S (Frame[Unknown]): 2D input text data with an index column.
49+
# maxTokens (Int): Maximum number of tokens per text entry.
50+
# ------------------------------------------------------------------------------
51+
# OUTPUT:
52+
# ------------------------------------------------------------------------------
53+
# result (Frame[Unknown]): Tokenized words.
54+
# docID (Matrix[double]): Document ID matrix corresponding to tokens.
55+
# ------------------------------------------------------------------------------
56+
getWordPosition = function(Frame[Unknown] S, Int maxTokens) return (Frame[Unknown] result, Matrix[double] docID){
  # Trace message marking the start of this pipeline stage.
  print("getWordPosition");

  # Tokenizer spec: split algorithm, position output, column 1 is the id column
  # and column 2 is the text column to tokenize.
  tokenSpec = "{\"algo\": \"split\", \"out\": \"position\",\"out_params\": {\"sort_alpha\": false},\"id_cols\": [1],\"tokenize_col\": 2}";
  tokens = tokenize(target=S, spec=tokenSpec, max_tokens=maxTokens);

  # Column 1 of the tokenizer output is returned as the numeric document id,
  # column 3 as the token frame.
  docID = as.matrix(tokens[,1]);
  result = tokens[,3];
}
63+
64+
## Encodes words into a numerical matrix format, retrieves the vocabulary size, and maps word indices.
65+
## Uses transformencode() to recode strings and find each unique string position in the co-occurrence matrix.
66+
# INPUT:
67+
# ------------------------------------------------------------------------------
68+
# S (Frame[Unknown]): 1D frame of tokenized word positions.
69+
# ------------------------------------------------------------------------------
70+
# OUTPUT:
71+
# ------------------------------------------------------------------------------
72+
# recodedWordPosition (Matrix[double]): Encoded word positions as a numerical matrix.
73+
# tableSize (Int): Number of distinct words in the input text (co-occurrence matrix size).
74+
# column (Frame[Unknown]): Mapping of word indices to distinct words in the co-occurrence matrix.
75+
# ------------------------------------------------------------------------------
76+
getRecodedMatrix = function(Frame[Unknown] S) return (Matrix[double] recodedWordPosition, Int tableSize, Frame[Unknown] column){
  # Trace message marking the start of this pipeline stage.
  print("getRecodedMatrix");

  # Recode the single token column; recodeMap is the metadata frame of the encoder.
  [recodedWordPosition, recodeMap] = transformencode(target=S, spec="{ids:true,recode:[1]}");

  # Split each metadata entry into its two parts: element [0] is kept as the word,
  # element [1] is parsed as the integer code.
  words = map(recodeMap[,1], "s -> UtilFunctions.splitRecodeEntry(s)[0]");
  codes = map(recodeMap[,1], "s -> Integer.valueOf(UtilFunctions.splitRecodeEntry(s)[1])");

  # One row per metadata entry.
  tableSize = nrow(words);

  # Start from the unsorted (code, word) pairs, then rewrite rows in code order.
  column = cbind(codes, words);
  sortPerm = order(target=as.matrix(codes), by=1, decreasing=FALSE, index.return=TRUE);

  #TODO vectorize via order of frames
  for(k in 1:nrow(sortPerm)){
    src = as.integer(as.scalar(sortPerm[k,1]));
    # Place the word into the row addressed by its own code ...
    column[as.integer(as.scalar(codes[src])), 2] = words[src];
    # ... and write the k-th smallest code into row k of the code column.
    column[k, 1] = as.scalar(codes[src]);
  }
}
93+
94+
## Iterates over the recoded word positions to construct a co-occurrence matrix.
95+
# INPUT:
96+
# ------------------------------------------------------------------------------
97+
# recodedWordPosition (Matrix[double]): 2D matrix of recoded word positions with text IDs.
98+
# tableSize (Int): Size of the vocabulary (number of unique words).
99+
# distanceWeighting (Boolean): Flag to apply distance weighting to co-occurrence counts.
100+
# symmetric (Boolean): If TRUE, both left and right context words are counted; if FALSE, only the left context is counted.
101+
# windowSize (Int): Context window size.
102+
# ------------------------------------------------------------------------------
103+
# OUTPUT:
104+
# ------------------------------------------------------------------------------
105+
# coocMatrix (Matrix[double]): Final word-word co-occurrence matrix.
106+
# ------------------------------------------------------------------------------
107+
createCoocMatrix = function(
    Matrix[double] recodedWordPosition,
    Int tableSize,
    boolean distanceWeighting,
    boolean symmetric,
    Int windowSize)
  return (Matrix[double] coocMatrix)
{
  # Builds a tableSize x tableSize count matrix from a two-column input:
  # column 1 = document id, column 2 = recoded word index.
  # For every token, neighbours up to windowSize rows away that share the same
  # document id add to coocMatrix[word, neighbour]. The left context is always
  # counted; the right context only if symmetric is TRUE. With distanceWeighting
  # the contribution of a neighbour at distance j is 1/j, otherwise 1.
  print("Processing word cooccurrence...");
  coocMatrix = matrix(0, tableSize, tableSize);
  numTokens = nrow(recodedWordPosition);

  #TODO vectorize loop
  for (i in 1:numTokens) {
    docId = as.integer(as.scalar(recodedWordPosition[i,1]));
    wordIndex = as.integer(as.scalar(recodedWordPosition[i,2]));
    # Code 0 is not a valid 1-based matrix index; transformencode has been seen to
    # emit it when running the JVM test, so such tokens are skipped.
    if (wordIndex != 0) {
      for (j in 1:windowSize) {
        # The weight depends only on the distance j, so compute it once per distance.
        increase = 1.0;
        if (distanceWeighting) {
          increase = 1.0 / j;
        }

        # Left context. The bounds check stays a separate nested `if` so that the
        # row lookup is never evaluated with an out-of-range index.
        if (i-j > 0) {
          if (docId == as.integer(as.scalar(recodedWordPosition[i-j, 1]))) {
            neighbourWordIndex = as.integer(as.scalar(recodedWordPosition[i-j,2]));
            # Same guard as for the centre word: a neighbour code of 0 would
            # address column 0 of the matrix.
            if (neighbourWordIndex != 0) {
              coocMatrix[wordIndex, neighbourWordIndex] = coocMatrix[wordIndex, neighbourWordIndex] + increase;
            }
          }
        }

        # Right context, only when a symmetric window is requested.
        if (symmetric) {
          if (i+j <= numTokens) {
            if (docId == as.integer(as.scalar(recodedWordPosition[i+j, 1]))) {
              neighbourWordIndex = as.integer(as.scalar(recodedWordPosition[i+j,2]));
              if (neighbourWordIndex != 0) {
                coocMatrix[wordIndex, neighbourWordIndex] = coocMatrix[wordIndex, neighbourWordIndex] + increase;
              }
            }
          }
        }
      }
    }
  }
  print("Word-word cooccurrence matrix computation completed.");
}
149+
150+
## Main function to process text data to construct a word-word co-occurrence matrix.
151+
# INPUT:
152+
# ------------------------------------------------------------------------------
153+
# input (Frame[Unknown]): 1D input corpus in CSV format.
154+
# maxTokens (Int): Maximum number of tokens per text entry.
155+
# windowSize (Int): Context window size.
156+
# distanceWeighting (Boolean): Whether to apply distance-based weighting.
157+
# symmetric (Boolean): If TRUE, both left and right context words are counted; if FALSE, only the left context is counted.
158+
# ------------------------------------------------------------------------------
159+
# OUTPUT:
160+
# ------------------------------------------------------------------------------
161+
# coocMatrix (Matrix[double]): The computed co-occurrence matrix.
162+
# column (Frame[Unknown]): Word-index mapping for the co-occurrence matrix.
163+
# ------------------------------------------------------------------------------
164+
f_cooccurrenceMatrix = function(
    Frame[Unknown] input,
    Int maxTokens,
    Int windowSize,
    Boolean distanceWeighting,
    Boolean symmetric) return (Matrix[Double] coocMatrix, Frame[Unknown] column){

  # Stage 1: clean the raw text and attach a row id per input line.
  cleanedText = processText(input);

  # Stage 2: tokenize; yields the token frame and the document id of every token.
  [tokens, docID] = getWordPosition(cleanedText, maxTokens);

  # Stage 3: recode tokens to integer codes; also yields the vocabulary size and
  # the code-to-word mapping that is returned to the caller as `column`.
  [wordCodes, vocabSize, column] = getRecodedMatrix(tokens);

  # Stage 4: count co-occurrences over the (document id, word code) pairs.
  docAndCode = cbind(docID, wordCodes);
  coocMatrix = createCoocMatrix(docAndCode, vocabSize, distanceWeighting, symmetric, windowSize);
}

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ public enum Builtins {
9393
CONV2D("conv2d", false),
9494
CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),
9595
CONV2D_BACKWARD_DATA("conv2d_backward_data", false),
96+
COOCCURRENCEMATRIX("cooccurrenceMatrix", true),
9697
COR("cor", true),
9798
CORRECTTYPOS("correctTypos", true),
9899
CORRECTTYPOSAPPLY("correctTyposApply", true),
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part1;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
24+
import org.apache.sysds.test.AutomatedTestBase;
25+
import org.apache.sysds.test.TestConfiguration;
26+
import org.apache.sysds.test.TestUtils;
27+
import org.junit.Test;
28+
29+
import java.util.HashMap;
30+
31+
public class BuiltinCooccurrenceMatrixTest extends AutomatedTestBase {
32+
33+
private static final String TEST_NAME = "cooccurrenceMatrix";
34+
private static final String TEST_DIR = "functions/builtin/";
35+
private static final String RESOURCE_DIRECTORY = "src/test/resources/datasets/";
36+
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinCooccurrenceMatrixTest.class.getSimpleName() + "/";
37+
private static final double EPSILON = 1e-10; // Tolerance for comparison
38+
39+
@Override
40+
public void setUp() {
41+
addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"TestResult",}));
42+
}
43+
44+
@Test
45+
public void cooccurrenceMatrixTest() {
46+
runCooccurrenceMatrix(20, 2, "FALSE", "TRUE");
47+
HashMap<MatrixValue.CellIndex, Double> cooccurrenceMatrix = readDMLMatrixFromOutputDir("TestResult");
48+
double[][] computedC = TestUtils.convertHashMapToDoubleArray(cooccurrenceMatrix);
49+
50+
// Unique words: {apple, banana, orange, grape}
51+
// Co-occurrence based on word pairs in same sentences
52+
double[][] expectedC = new double[][] {
53+
{0, 1, 2, 0}, // apple with {banana, orange}
54+
{1, 0, 3, 1}, // banana with {apple, orange, grape}
55+
{2, 3, 0, 2}, // orange with {apple, banana, grape}
56+
{0, 1, 2, 0} // grape with {banana, orange, grape}
57+
};
58+
59+
TestUtils.compareMatrices(expectedC, computedC, expectedC.length, expectedC[0].length, EPSILON);
60+
61+
}
62+
63+
public void runCooccurrenceMatrix(Integer maxTokens, Integer windowSize, String distanceWeighting, String symmetric) {
64+
// Load test configuration
65+
Types.ExecMode platformOld = setExecMode(Types.ExecType.CP);
66+
try{
67+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
68+
69+
String HOME = SCRIPT_DIR + TEST_DIR;
70+
71+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
72+
73+
programArgs = new String[]{"-nvargs",
74+
"input=" + RESOURCE_DIRECTORY + "GloVe/coocMatrixTest.csv",
75+
"maxTokens=" + maxTokens,
76+
"windowSize=" + windowSize,
77+
"distanceWeighting=" + distanceWeighting,
78+
"symmetric=" + symmetric,
79+
"out_file=" + output("TestResult")};
80+
System.out.println("Run dml script..");
81+
runTest(true, false, null, -1);
82+
System.out.println("DONE");
83+
}
84+
finally {
85+
rtplatform = platformOld;
86+
}
87+
}
88+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
apple banana orange.
2+
banana orange grape.
3+
apple. orange
4+
grape 1111 ------ orange.
5+
------ <<<<<<< 1111 22222.
6+
banana orange
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read the corpus as a headerless, comma-separated frame.
corpus = read($input, data_type="frame", format="csv", sep=",", header=FALSE);

# Compute the co-occurrence matrix; the word-index mapping is returned but not written out.
[C, wordIndexMap] = cooccurrenceMatrix(corpus, $maxTokens, $windowSize, $distanceWeighting, $symmetric);

# Persist only the matrix for comparison in the JUnit test.
write(C, $out_file, data_type="matrix");

0 commit comments

Comments
 (0)