Skip to content

Commit d94ed5a

Browse files
authored
fix: NeuMF act_fn and num_factors parameter (#683)
1 parent aa75197 commit d94ed5a

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

cornac/models/ncf/backend_pt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
import torch.nn as nn
33

4-
54
optimizer_dict = {
65
"sgd": torch.optim.SGD,
76
"adam": torch.optim.Adam,
@@ -16,7 +15,7 @@
1615
"selu": nn.SELU(),
1716
"relu": nn.ReLU(),
1817
"relu6": nn.ReLU6(),
19-
"leakyrelu": nn.LeakyReLU(),
18+
"leaky_relu": nn.LeakyReLU(),
2019
}
2120

2221

cornac/models/ncf/recom_neumf.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
import numpy as np
1717

18-
from .recom_ncf_base import NCFBase
1918
from ...exception import ScoreException
19+
from .recom_ncf_base import NCFBase
2020

2121

2222
class NeuMF(NCFBase):
@@ -59,13 +59,13 @@ class NeuMF(NCFBase):
5959
6060
backend: str, optional, default: 'tensorflow'
6161
Backend used for model training: tensorflow, pytorch
62-
62+
6363
early_stopping: {min_delta: float, patience: int}, optional, default: None
64-
If `None`, no early stopping. Meaning of the arguments:
65-
64+
If `None`, no early stopping. Meaning of the arguments:
65+
6666
- `min_delta`: the minimum increase in monitored value on validation set to be considered as improvement, \
6767
i.e. an increment of less than min_delta will count as no improvement.
68-
68+
6969
- `patience`: number of epochs with no improvement after which training should be stopped.
7070
7171
name: string, optional, default: 'NeuMF'
@@ -159,12 +159,13 @@ def from_pretrained(self, pretrained_gmf, pretrained_mlp, alpha=0.5):
159159
########################
160160
def _build_model_tf(self):
161161
import tensorflow as tf
162+
162163
from .backend_tf import GMFLayer, MLPLayer
163-
164+
164165
# Define inputs
165166
user_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="user_input")
166167
item_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="item_input")
167-
168+
168169
# GMF layer
169170
gmf_layer = GMFLayer(
170171
num_users=self.num_users,
@@ -175,7 +176,7 @@ def _build_model_tf(self):
175176
seed=self.seed,
176177
name="gmf_layer"
177178
)
178-
179+
179180
# MLP layer
180181
mlp_layer = MLPLayer(
181182
num_users=self.num_users,
@@ -186,72 +187,72 @@ def _build_model_tf(self):
186187
seed=self.seed,
187188
name="mlp_layer"
188189
)
189-
190+
190191
# Get embeddings and element-wise product
191192
gmf_vector = gmf_layer([user_input, item_input])
192193
mlp_vector = mlp_layer([user_input, item_input])
193-
194+
194195
# Concatenate GMF and MLP vectors
195196
concat_vector = tf.keras.layers.Concatenate(axis=-1)([gmf_vector, mlp_vector])
196-
197+
197198
# Output layer
198199
logits = tf.keras.layers.Dense(
199200
1,
200201
kernel_initializer=tf.keras.initializers.LecunUniform(seed=self.seed),
201202
name="logits"
202203
)(concat_vector)
203-
204+
204205
prediction = tf.keras.layers.Activation('sigmoid', name="prediction")(logits)
205-
206+
206207
# Create model
207208
model = tf.keras.Model(
208209
inputs=[user_input, item_input],
209210
outputs=prediction,
210211
name="NeuMF"
211212
)
212-
213+
213214
# Handle pretrained models
214215
if self.pretrained:
215216
# Get GMF and MLP models
216217
gmf_model = self.pretrained_gmf.model
217218
mlp_model = self.pretrained_mlp.model
218-
219+
219220
# Copy GMF embeddings
220221
model.get_layer('gmf_layer').user_embedding.set_weights(
221222
gmf_model.get_layer('gmf_layer').user_embedding.get_weights()
222223
)
223224
model.get_layer('gmf_layer').item_embedding.set_weights(
224225
gmf_model.get_layer('gmf_layer').item_embedding.get_weights()
225226
)
226-
227+
227228
# Copy MLP embeddings and layers
228229
model.get_layer('mlp_layer').user_embedding.set_weights(
229230
mlp_model.get_layer('mlp_layer').user_embedding.get_weights()
230231
)
231232
model.get_layer('mlp_layer').item_embedding.set_weights(
232233
mlp_model.get_layer('mlp_layer').item_embedding.get_weights()
233234
)
234-
235+
235236
# Copy dense layers in MLP
236237
for i, layer in enumerate(model.get_layer('mlp_layer').dense_layers):
237238
layer.set_weights(mlp_model.get_layer('mlp_layer').dense_layers[i].get_weights())
238-
239+
239240
# Combine weights for output layer
240241
gmf_logits_weights = gmf_model.get_layer('logits').get_weights()
241242
mlp_logits_weights = mlp_model.get_layer('logits').get_weights()
242-
243+
243244
# Combine kernel weights
244245
combined_kernel = np.concatenate([
245246
self.alpha * gmf_logits_weights[0],
246247
(1.0 - self.alpha) * mlp_logits_weights[0]
247248
], axis=0)
248-
249+
249250
# Combine bias weights
250251
combined_bias = self.alpha * gmf_logits_weights[1] + (1.0 - self.alpha) * mlp_logits_weights[1]
251-
252+
252253
# Set combined weights to output layer
253254
model.get_layer('logits').set_weights([combined_kernel, combined_bias])
254-
255+
255256
return model
256257

257258
#####################
@@ -264,6 +265,7 @@ def _build_model_pt(self):
264265
num_users=self.num_users,
265266
num_items=self.num_items,
266267
layers=self.layers,
268+
num_factors=self.num_factors,
267269
act_fn=self.act_fn,
268270
)
269271
if self.pretrained:

0 commit comments

Comments (0)