Skip to content

Commit f4d9faf

Browse files
authored
Upgrade tf2 for NCF models
1 parent 9890c7a commit f4d9faf

File tree

4 files changed

+17
-18
lines changed

4 files changed

+17
-18
lines changed

cornac/models/ncf/ops.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16+
import warnings
1617

17-
import tensorflow as tf
18+
# disable annoying tensorflow deprecated API warnings
19+
warnings.filterwarnings("ignore", category=UserWarning)
20+
21+
import tensorflow.compat.v1 as tf
22+
23+
tf.logging.set_verbosity(tf.logging.ERROR)
24+
tf.disable_v2_behavior()
1825

1926

2027
act_functions = {
@@ -60,14 +67,14 @@ def emb(
6067
shape=[num_users, emb_size],
6168
dtype=tf.float32,
6269
initializer=tf.random_normal_initializer(stddev=0.01, seed=seed),
63-
regularizer=tf.contrib.layers.l2_regularizer(scale=reg_user),
70+
regularizer=tf.keras.regularizers.L2(reg_user),
6471
)
6572
item_emb = tf.get_variable(
6673
"item_emb",
6774
shape=[num_items, emb_size],
6875
dtype=tf.float32,
6976
initializer=tf.random_normal_initializer(stddev=0.01, seed=seed),
70-
regularizer=tf.contrib.layers.l2_regularizer(scale=reg_item),
77+
regularizer=tf.keras.regularizers.L2(reg_item),
7178
)
7279

7380
return tf.nn.embedding_lookup(user_emb, uid), tf.nn.embedding_lookup(item_emb, iid)
@@ -96,7 +103,7 @@ def mlp(uid, iid, num_users, num_items, layers, reg_layers, act_fn, seed=None):
96103
iid=iid,
97104
num_users=num_users,
98105
num_items=num_items,
99-
emb_size=layers[0] / 2,
106+
emb_size=int(layers[0] / 2),
100107
reg_user=reg_layers[0],
101108
reg_item=reg_layers[0],
102109
seed=seed,
@@ -110,7 +117,6 @@ def mlp(uid, iid, num_users, num_items, layers, reg_layers, act_fn, seed=None):
110117
name="layer{}".format(i + 1),
111118
activation=act_functions.get(act_fn, tf.nn.relu),
112119
kernel_initializer=tf.initializers.lecun_uniform(seed),
113-
kernel_regularizer=tf.contrib.layers.l2_regularizer(reg_layers[i + 1]),
120+
kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]),
114121
)
115122
return interaction
116-

cornac/models/ncf/recom_gmf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16+
1617
import numpy as np
1718

1819
from .recom_ncf_base import NCFBase

cornac/models/ncf/recom_ncf_base.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16-
import os
17-
import copy
16+
1817
from tqdm.auto import trange
1918

2019
from ..recommender import Recommender
21-
from ...exception import ScoreException
2220
from ...utils import get_rng
2321

2422

@@ -132,10 +130,6 @@ def fit(self, train_set, val_set=None):
132130
def _build_graph(self):
133131
import tensorflow.compat.v1 as tf
134132

135-
# less verbose TF
136-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
137-
tf.logging.set_verbosity(tf.logging.ERROR)
138-
139133
self.graph = tf.Graph()
140134

141135
def _sess_init(self):
@@ -158,8 +152,6 @@ def _step_update(self, batch_users, batch_items, batch_ratings):
158152
return _loss
159153

160154
def _fit_tf(self):
161-
import tensorflow.compat.v1 as tf
162-
163155
loop = trange(self.num_epochs, disable=not self.verbose)
164156
for _ in loop:
165157
count = 0
@@ -210,9 +202,9 @@ def load(model_path, trainable=False):
210202
provided, the latest model will be loaded.
211203
212204
trainable: boolean, optional, default: False
213-
Set it to True if you would like to finetune the model. By default,
205+
Set it to True if you would like to finetune the model. By default,
214206
the model parameters are assumed to be fixed after being loaded.
215-
207+
216208
Returns
217209
-------
218210
self : object

cornac/models/ncf/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
tensorflow>=1.15.2,<2.0.0
1+
tensorflow==2.12.0

0 commit comments

Comments
 (0)