Commit 96e93b8

Fix ConvMF loss info (#349)
1 parent 54e870d · commit 96e93b8

1 file changed: +29 -17 lines changed


cornac/models/conv_mf/recom_convmf.py

Lines changed: 29 additions & 17 deletions
@@ -38,6 +38,12 @@ class ConvMF(Recommender):
 
     cnn_epochs: int, optional, default: 5
         Number of epochs for optimizing the CNN for each overall training epoch.
+
+    cnn_bs: int, optional, default: 128
+        Batch size for optimizing CNN.
+
+    cnn_lr: float, optional, default: 0.001
+        Learning rate for optimizing CNN.
 
     lambda_u: float, optional, default: 1.0
         The regularization hyper-parameter for user latent factor.
@@ -85,6 +91,8 @@ def __init__(
         k=50,
         n_epochs=50,
         cnn_epochs=5,
+        cnn_bs=128,
+        cnn_lr=0.001,
         lambda_u=1,
         lambda_v=100,
         emb_dim=200,
@@ -102,6 +110,8 @@ def __init__(
         super().__init__(name=name, trainable=trainable, verbose=verbose)
         self.give_item_weight = give_item_weight
         self.n_epochs = n_epochs
+        self.cnn_bs = cnn_bs
+        self.cnn_lr = cnn_lr
         self.lambda_u = lambda_u
         self.lambda_v = lambda_v
         self.k = k
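
The two new hyperparameters are threaded from the constructor into the CNN module and the training loop patched below. A minimal usage sketch, assuming the standard cornac import path (the specific values are illustrative, not part of this commit):

    from cornac.models import ConvMF

    # cnn_bs replaces the CNN batch size that was previously hard-coded to 128;
    # cnn_lr is forwarded to the CNN module as its learning rate.
    model = ConvMF(
        k=50,
        n_epochs=50,
        cnn_epochs=5,
        cnn_bs=64,      # batch size for optimizing the CNN
        cnn_lr=0.0005,  # learning rate for optimizing the CNN
        lambda_u=1,
        lambda_v=100,
    )
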
@@ -191,7 +201,7 @@ def _fit_convmf(self):
         # Initialize cnn module
         import tensorflow.compat.v1 as tf
         from .convmf import CNN_module
-
+
         # less verbose TF
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
         tf.logging.set_verbosity(tf.logging.ERROR)
@@ -207,6 +217,7 @@ def _fit_convmf(self):
             hidden_dim=self.hidden_dim,
             seed=self.seed,
             init_W=self.W,
+            learning_rate=self.cnn_lr,
         )
 
         config = tf.ConfigProto()
@@ -230,10 +241,10 @@ def _fit_convmf(self):
         for epoch in range(1, self.n_epochs + 1):
             if self.verbose:
                 print("Epoch: {}/{}".format(epoch, self.n_epochs))
-
+
             tic = time.time()
 
-            user_loss = np.zeros(n_user)
+            user_loss = 0.0
             for i in range(n_user):
                 idx_item = user_data[0][i]
                 V_i = self.V[idx_item]
@@ -243,9 +254,9 @@ def _fit_convmf(self):
                 B = (V_i * (np.tile(R_i, (self.k, 1)).T)).sum(0)
                 self.U[i] = np.linalg.solve(A, B)
 
-                user_loss[i] = -0.5 * self.lambda_u * np.dot(self.U[i], self.U[i])
+                user_loss += self.lambda_u * np.dot(self.U[i], self.U[i])
 
-            item_loss = np.zeros(n_item)
+            item_loss = 0.0
             for j in range(n_item):
                 idx_user = item_data[0][j]
                 U_j = self.U[idx_user]
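
For context on the two alternating least-squares loops patched above: reading A and B off the surrounding code, each np.linalg.solve(A, B) appears to compute the standard ConvMF coordinate update, e.g. for user i with rated-item set I_i

    u_i = \left( V_{I_i}^\top V_{I_i} + \lambda_u I_k \right)^{-1} V_{I_i}^\top r_i

so the quantity now accumulated into user_loss is the matching regularizer \lambda_u u_i^\top u_i; the item update for v_j is symmetric, with the CNN output theta[j] entering through the \lambda_v term.
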
@@ -257,11 +268,15 @@ def _fit_convmf(self):
                 ) + self.lambda_v * item_weight[j] * theta[j]
                 self.V[j] = np.linalg.solve(A, B)
 
-                item_loss[j] = -np.square(R_j - U_j.dot(self.V[j])).sum()
+                item_loss += np.square(R_j - U_j.dot(self.V[j])).sum()
 
-            loop = trange(self.cnn_epochs, desc="Optimizing CNN", disable=not self.verbose)
+            loop = trange(
+                self.cnn_epochs, desc="Optimizing CNN", disable=not self.verbose
+            )
             for _ in loop:
-                for batch_ids in self.train_set.item_iter(batch_size=128, shuffle=True):
+                for batch_ids in self.train_set.item_iter(
+                    batch_size=self.cnn_bs, shuffle=True
+                ):
                     batch_seq = self.train_set.item_text.batch_seq(
                         batch_ids, max_length=self.max_len
                     )
@@ -282,21 +297,18 @@ def _fit_convmf(self):
                         [cnn_module.model_output, cnn_module.weighted_loss], feed_dict=feed_dict
                     )
 
-            loss = (
-                loss
-                + np.sum(user_loss)
-                + np.sum(item_loss)
-                - 0.5 * self.lambda_v * cnn_loss
-            )
+            loss = 0.5 * (user_loss + item_loss + self.lambda_v * cnn_loss)
+
             toc = time.time()
             elapsed = toc - tic
             converge = abs((loss - history) / history)
-
+
             if self.verbose:
                 print(
-                    "Loss: %.5f Elpased: %.4fs Converge: %.6f " % (loss, elapsed, converge)
+                    "Loss: %.5f Elapsed: %.4fs Converge: %.6f "
+                    % (loss, elapsed, converge)
                 )
-
+
             history = loss
             if converge < converge_threshold:
                 endure -= 1
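
The headline change is the reported training loss. Before this commit, user_loss and item_loss were per-element arrays filled with mixed signs and folded into a stale running loss value; afterwards they are plain scalar accumulators combined once per epoch. Matching the three accumulators against the ConvMF objective (a reading of the diff, not text from the commit), the printed quantity is

    \mathcal{L} = \frac{1}{2} \left( \lambda_u \sum_i u_i^\top u_i
        + \sum_j \sum_{i \in \Omega_j} \left( r_{ij} - u_i^\top v_j \right)^2
        + \lambda_v \, \ell_{\mathrm{cnn}} \right)

where \Omega_j is the set of users who rated item j, the first term comes from user_loss, the second from item_loss, and \ell_{\mathrm{cnn}} is the cnn_module.weighted_loss value returned by the TensorFlow session. This is exactly what loss = 0.5 * (user_loss + item_loss + self.lambda_v * cnn_loss) evaluates; it is also non-negative, which keeps the convergence ratio abs((loss - history) / history) well-behaved.
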
