Skip to content

Commit 467c553

Browse files
xixuanzhang2022saminbassiri
authored and committed
[SYSTEMDS-3179] Builtin for GloVe word embedding training
Closes #2201. Co-authored-by: Samin <[email protected]>
1 parent 6fd08c0 commit 467c553

File tree

5 files changed

+867
-0
lines changed

5 files changed

+867
-0
lines changed

scripts/builtin/glove.dml

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#-------------------------------------------------------------
20+
21+
init = function(matrix[double] cooc_matrix, double x_max, double alpha)
return(matrix[double] weights, matrix[double] log_cooc_matrix){
  /*
   * Prepares the two cell-wise quantities consumed by the GloVe training loop.
   *
   * Inputs:
   * - cooc_matrix: Co-occurrence counts.
   * - x_max: Count at which the weighting function saturates.
   * - alpha: Exponent of the weighting function.
   *
   * Outputs:
   * - weights: min(1, (min(X, x_max) / x_max) ^ alpha), evaluated cell-wise.
   * - log_cooc_matrix: Natural logarithm of X where X > 0, and 0 everywhere else.
   */

  # cap the counts at x_max so that the scaled ratio never exceeds 1
  bounded = pmin(cooc_matrix, x_max);
  weights = pmin(1, (bounded / x_max) ^ alpha);

  # Single-argument log() is the natural logarithm. Using it avoids the rounding
  # error of passing a hand-typed, truncated value of e as an explicit base.
  # ifelse() maps cells without co-occurrence to 0 instead of log(0) = -Inf.
  log_cooc_matrix = ifelse(cooc_matrix > 0, log(cooc_matrix), 0);
}
28+
29+
gloveWithCoocMatrix = function(matrix[double] cooc_matrix, frame[Unknown] cooc_index, int seed, int vector_size, double alpha, double eta, double x_max, double tol, int iterations,int print_loss_it)
return (frame[Unknown] G){
  /*
   * Computes the vector embeddings for words by analyzing their co-occurrence statistics in a large text corpus.
   *
   * Inputs:
   * - cooc_matrix: Precomputed co-occurrence matrix of shape (N, N).
   * - cooc_index: Index frame mapping words to their positions in the co-occurrence matrix.
   *   The second column should contain the word list in the same order as the matrix.
   * - seed: Random seed for reproducibility (seed .. seed+3 are used for the four parameter blocks).
   * - vector_size: Dimensionality of word vectors, V.
   * - alpha: Weighting function parameter, recommended value: 0.75.
   * - eta: Learning rate for optimization, recommended value: 0.05.
   * - x_max: Maximum co-occurrence value as per the GloVe paper: 100.
   * - tol: Convergence tolerance, recommended value: 1e-4. Training stops as soon as the
   *   absolute change of the loss between two consecutive iterations drops below tol.
   * - iterations: Maximum number of update steps.
   * - print_loss_it: Interval (in iterations) for printing the loss.
   *
   * Outputs:
   * - G: frame of shape (N, V+1). The first column is the second column of cooc_index (the words),
   *   the remaining V columns hold the word vector W + C of the respective word.
   */

  vocab_size = nrow(cooc_matrix);

  # word vectors (W), context vectors (C) and their biases, drawn from [-0.5, 0.5) / vector_size
  W = (rand(rows=vocab_size, cols=vector_size, min=0, max=1, seed=seed)-0.5)/vector_size;
  C = (rand(rows=vocab_size, cols=vector_size, min=0, max=1, seed=seed+1)-0.5)/vector_size;
  bw = (rand(rows=vocab_size, cols=1, min=0, max=1, seed=seed+2)-0.5)/vector_size;
  bc = (rand(rows=vocab_size, cols=1, min=0, max=1, seed=seed+3)-0.5)/vector_size;
  [weights, log_cooc_matrix] = init(cooc_matrix, x_max, alpha);

  # running sums of squared gradients; each step below is divided by their square root
  # (AdaGrad-style scaling), so they start at a small positive value
  momentum_W = 1e-8 + 0.1 * matrix(1, nrow(W), ncol(W));
  momentum_C = 1e-8 + 0.1 * matrix(1, nrow(C), ncol(C));
  momentum_bw = 1e-8 + 0.1 * matrix(1, nrow(bw), ncol(bw));
  momentum_bc = 1e-8 + 0.1 * matrix(1, nrow(bc), ncol(bc));

  error = 0;
  iter = 0;
  # defined up front so the summary below is valid even if no update step is ever taken
  # (e.g., iterations <= 0)
  final_iter = 0;
  tolerance = tol;
  previous_error = 1e10;
  conti = TRUE;

  while (conti) {

    # compute predictions for all co-occurring word pairs at once
    predictions = W %*% t(C) + bw + t(bc);
    diffs = predictions - log_cooc_matrix;
    weighted_diffs = weights * diffs;

    # compute gradients
    wgrad = weighted_diffs %*% C;
    cgrad = t(weighted_diffs) %*% W;
    bwgrad = rowSums(weighted_diffs);
    bcgrad = matrix(colSums(weighted_diffs), nrow(bc), ncol(bc));

    # weighted least-squares loss of the current parameters
    error = sum(0.5 * (weights * (diffs ^ 2)));
    iter = iter + 1;

    # continue only while the loss still changes by at least tol and the budget is not used up
    if (abs(previous_error - error) >= tolerance) {
      if(iter <= iterations){

        # get steps and update
        momentum_W = momentum_W + (wgrad ^ 2);
        momentum_C = momentum_C + (cgrad ^ 2);
        momentum_bw = momentum_bw + (bwgrad ^ 2);
        momentum_bc = momentum_bc + (bcgrad ^ 2);

        W = W - (eta * wgrad / (sqrt(momentum_W) + 1e-8));
        C = C - (eta * cgrad / (sqrt(momentum_C) + 1e-8));
        bw = bw - (eta * bwgrad / (sqrt(momentum_bw) + 1e-8));
        bc = bc - (eta * bcgrad / (sqrt(momentum_bc) + 1e-8));

        previous_error = error;

        final_iter = iter;
      } else {
        conti = FALSE;
      }
    } else {
      conti = FALSE;
    }

    # print the loss every print_loss_it iterations (iter is a multiple of print_loss_it)
    if (iter - floor(iter / print_loss_it) * print_loss_it == 0) {
      print("iteration: " + iter + " error: " + error);
    }
  }

  print("Given " + iterations + " iterations, " + "stopped (or converged) at the " + final_iter + " iteration / error: " + error);

  # The final embedding is the sum of word and context vectors. It is computed once here,
  # because W and C no longer change after the loop.
  embeddings = W + C;

  # add the word index to the word vectors
  G = cbind(cooc_index[,2], as.frame(embeddings));
}
121+
122+
glove = function(
    Frame[Unknown] input,
    int seed, int vector_size,
    double alpha, double eta,
    double x_max,
    double tol,
    int iterations,
    int print_loss_it,
    Int maxTokens,
    Int windowSize,
    Boolean distanceWeighting,
    Boolean symmetric)
return (frame[Unknown] G){

  /*
   * Entry point: trains GloVe word embeddings directly from a text corpus.
   * The corpus is first turned into a co-occurrence matrix plus word index via
   * cooccurrenceMatrix(), which are then handed to gloveWithCoocMatrix() for training.
   * INPUT:
   * ------------------------------------------------------------------------------
   * - input (Frame[Unknown]): Input corpus (one-dimensional, read from CSV).
   * - seed: Random seed for reproducibility.
   * - vector_size: Dimensionality of word vectors, V.
   * - alpha: Weighting function parameter, recommended value: 0.75.
   * - eta: Learning rate for optimization, recommended value: 0.05.
   * - x_max: Maximum co-occurrence value as per the GloVe paper: 100.
   * - tol: Convergence tolerance on the change of the loss, recommended value: 1e-4.
   * - iterations: Maximum number of training iterations.
   * - print_loss_it: Interval (in iterations) for printing the loss.
   * - maxTokens (Int): Maximum number of tokens per text entry.
   * - windowSize (Int): Context window size.
   * - distanceWeighting (Boolean): Whether to apply distance-based weighting.
   * - symmetric (Boolean): Determines if the matrix is symmetric (TRUE) or asymmetric (FALSE).
   * ------------------------------------------------------------------------------
   * OUTPUT:
   * ------------------------------------------------------------------------------
   * G (Frame[Unknown]): One row per word, holding the word followed by its V-dimensional word vector.
   * ------------------------------------------------------------------------------
   */

  # step 1: corpus -> co-occurrence counts and the matching word index
  [counts, vocabulary] = cooccurrenceMatrix(input, maxTokens, windowSize, distanceWeighting, symmetric);

  # step 2: train the embeddings on the counts
  G = gloveWithCoocMatrix(counts, vocabulary, seed, vector_size, alpha, eta, x_max, tol, iterations, print_loss_it);
}

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ public enum Builtins {
154154
GET_ACCURACY("getAccuracy", true),
155155
GLM("glm", true),
156156
GLM_PREDICT("glmPredict", true),
157+
GLOVE("glove", true),
157158
GMM("gmm", true),
158159
GMM_PREDICT("gmmPredict", true),
159160
GNMF("gnmf", true),
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part1;
21+
22+
import java.io.IOException;
23+
import java.util.Objects;
24+
25+
import org.apache.sysds.common.Types;
26+
import org.apache.sysds.common.Types.FileFormat;
27+
import org.apache.sysds.runtime.frame.data.FrameBlock;
28+
import org.apache.sysds.test.AutomatedTestBase;
29+
import org.apache.sysds.test.TestConfiguration;
30+
import org.junit.Test;
31+
32+
public class BuiltinGloVeTest extends AutomatedTestBase {
33+
34+
private static final String TEST_NAME = "glove";
35+
private static final String TEST_DIR = "functions/builtin/";
36+
private static final String RESOURCE_DIRECTORY = "./src/test/resources/datasets/";
37+
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGloVeTest.class.getSimpleName() + "/";
38+
39+
private static final int TOP_K = 5;
40+
private static final double ACCURACY_THRESHOLD = 0.85;
41+
42+
private static final double seed = 45;
43+
private static final double vector_size = 100;
44+
private static final double alpha = 0.75;
45+
private static final double eta = 0.05;
46+
private static final double x_max = 100;
47+
private static final double tol = 1e-4;
48+
private static final double iterations = 10000;
49+
private static final double print_loss_it = 100;
50+
private static final double maxTokens = 2600;
51+
private static final double windowSize = 15;
52+
private static final String distanceWeighting = "TRUE";
53+
private static final String symmetric = "TRUE";
54+
55+
@Override
56+
public void setUp() {
57+
addTestConfiguration(TEST_NAME,
58+
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"out_result"}));
59+
}
60+
61+
@Test
62+
public void gloveTest() throws IOException{
63+
// Using top-5 words for similarity comparison
64+
runGloVe(TOP_K);
65+
66+
// Read the computed similarity results from SystemDS
67+
FrameBlock computedSimilarity = readDMLFrameFromHDFS("out_result", FileFormat.CSV);
68+
69+
// Load expected results (precomputed in Python)
70+
FrameBlock expectedSimilarity = readDMLFrameFromHDFS(RESOURCE_DIRECTORY + "/GloVe/gloveExpectedTop10.csv", FileFormat.CSV, false);
71+
72+
// Compute accuracy by comparing computed and expected results
73+
double accuracy = computeAccuracy(computedSimilarity, expectedSimilarity, TOP_K);
74+
75+
System.out.println("Computed Accuracy: " + accuracy);
76+
77+
// Ensure accuracy is above a reasonable threshold
78+
assert accuracy > ACCURACY_THRESHOLD : "Accuracy too low! Expected > 85% match.";
79+
}
80+
81+
public void runGloVe(int topK) {
82+
// Load test configuration
83+
Types.ExecMode platformOld = setExecMode(Types.ExecType.CP);
84+
try {
85+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
86+
87+
String HOME = SCRIPT_DIR + TEST_DIR;
88+
89+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
90+
91+
programArgs = new String[] {
92+
"-nvargs",
93+
"input=" + RESOURCE_DIRECTORY + "20news/20news_subset_untokenized.csv",
94+
"seed=" + seed,
95+
"vector_size=" + vector_size,
96+
"alpha=" + alpha,
97+
"eta=" + eta,
98+
"x_max=" + x_max,
99+
"tol=" + tol,
100+
"iterations=" + iterations,
101+
"print_loss_it=" + print_loss_it,
102+
"maxTokens=" + maxTokens,
103+
"windowSize=" + windowSize,
104+
"distanceWeighting=" + distanceWeighting,
105+
"symmetric=" + symmetric,
106+
"topK=" + topK,
107+
"out_result=" + output("out_result")
108+
};
109+
110+
System.out.println("Running DML script...");
111+
runTest(true, false, null, -1);
112+
System.out.println("Test execution completed.");
113+
} finally {
114+
rtplatform = platformOld;
115+
}
116+
}
117+
118+
/**
119+
* Computes accuracy by comparing top-K similar words between computed and expected results.
120+
*/
121+
private double computeAccuracy(FrameBlock computed, FrameBlock expected, int k) {
122+
int count = 0;
123+
for (int i = 0; i < computed.getNumRows(); i++) {
124+
int matchCount = 0;
125+
for (int j = 1; j < k; j++) {
126+
String word1 = computed.getString(i, j);
127+
for (int m = 0; m < k; m++) {
128+
if (Objects.equals(word1, expected.getString(i, m))) {
129+
matchCount++;
130+
break;
131+
}
132+
}
133+
}
134+
if (matchCount > 0) {
135+
count++;
136+
}
137+
}
138+
return (double) count / computed.getNumRows();
139+
}
140+
}

0 commit comments

Comments
 (0)