Skip to content

Commit 8138cf1

Browse files
authored
Update several models to be compatible with TensorFlow 2.0 (#504)
* Change to tensorflow.compat.v1 for WMF model * Change to tensorflow.compat.v1 for CDL model * Change to tensorflow.compat.v1 for CDR model * Change to tensorflow.compat.v1 for ConvMF model * Fix compatibility issues for CVAE * Update requirements.txt for NCF model * Fix numpy dtype compatibility issue
1 parent 757edb6 commit 8138cf1

File tree

22 files changed

+57
-49
lines changed

22 files changed

+57
-49
lines changed

cornac/data/dataset.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def dok_matrix(self):
272272
"""The user-item interaction matrix in DOK sparse format"""
273273
if self.__dok_matrix is None:
274274
self.__dok_matrix = dok_matrix(
275-
(self.num_users, self.num_items), dtype=np.float32
275+
(self.num_users, self.num_items), dtype='float'
276276
)
277277
for u, i, r in zip(*self.uir_tuple):
278278
self.__dok_matrix[u, i] = r
@@ -364,13 +364,13 @@ def build(
364364
raise ValueError("data is empty after being filtered!")
365365

366366
uir_tuple = (
367-
np.asarray(u_indices, dtype=np.int),
368-
np.asarray(i_indices, dtype=np.int),
369-
np.asarray(r_values, dtype=np.float),
367+
np.asarray(u_indices, dtype='int'),
368+
np.asarray(i_indices, dtype='int'),
369+
np.asarray(r_values, dtype='float'),
370370
)
371371

372372
timestamps = (
373-
np.fromiter((int(data[i][3]) for i in valid_idx), dtype=np.int)
373+
np.fromiter((int(data[i][3]) for i in valid_idx), dtype='int')
374374
if fmt == "UIRT"
375375
else None
376376
)
@@ -447,7 +447,7 @@ def idx_iter(self, idx_range, batch_size=1, shuffle=False):
447447
448448
Returns
449449
-------
450-
iterator : batch of indices (array of np.int)
450+
iterator : batch of indices (array of 'int')
451451
452452
"""
453453
indices = np.arange(idx_range)
@@ -481,8 +481,8 @@ def uir_iter(self, batch_size=1, shuffle=False, binary=False, num_zeros=0):
481481
482482
Returns
483483
-------
484-
iterator : batch of users (array of np.int), batch of items (array of np.int),
485-
batch of ratings (array of np.float)
484+
iterator : batch of users (array of 'int'), batch of items (array of 'int'),
485+
batch of ratings (array of 'float')
486486
487487
"""
488488
for batch_ids in self.idx_iter(len(self.uir_tuple[0]), batch_size, shuffle):
@@ -524,8 +524,8 @@ def uij_iter(self, batch_size=1, shuffle=False, neg_sampling="uniform"):
524524
525525
Returns
526526
-------
527-
iterator : batch of users (array of np.int), batch of positive items (array of np.int),
528-
batch of negative items (array of np.int)
527+
iterator : batch of users (array of 'int'), batch of positive items (array of 'int'),
528+
batch of negative items (array of 'int')
529529
530530
"""
531531

@@ -562,9 +562,9 @@ def user_iter(self, batch_size=1, shuffle=False):
562562
563563
Returns
564564
-------
565-
iterator : batch of user indices (array of np.int)
565+
iterator : batch of user indices (array of 'int')
566566
"""
567-
user_indices = np.fromiter(self.user_indices, dtype=np.int)
567+
user_indices = np.fromiter(self.user_indices, dtype='int')
568568
for batch_ids in self.idx_iter(len(user_indices), batch_size, shuffle):
569569
yield user_indices[batch_ids]
570570

@@ -580,9 +580,9 @@ def item_iter(self, batch_size=1, shuffle=False):
580580
581581
Returns
582582
-------
583-
iterator : batch of item indices (array of np.int)
583+
iterator : batch of item indices (array of 'int')
584584
"""
585-
item_indices = np.fromiter(self.item_indices, np.int)
585+
item_indices = np.fromiter(self.item_indices, 'int')
586586
for batch_ids in self.idx_iter(len(item_indices), batch_size, shuffle):
587587
yield item_indices[batch_ids]
588588

cornac/data/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def _build_triplet(self, id_map):
6161
self.map_cid.append(id_map[j])
6262
self.val.append(v)
6363

64-
self.map_rid = np.asarray(self.map_rid, dtype=np.int)
65-
self.map_cid = np.asarray(self.map_cid, dtype=np.int)
66-
self.val = np.asarray(self.val, dtype=np.float)
64+
self.map_rid = np.asarray(self.map_rid, dtype='int')
65+
self.map_cid = np.asarray(self.map_cid, dtype='int')
66+
self.val = np.asarray(self.val, dtype='float')
6767

6868
def build(self, id_map=None, **kwargs):
6969
super().build(id_map=id_map)

cornac/data/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ def batch_seq(self, batch_ids, max_length=None):
914914
if max_length is None:
915915
max_length = max(len(self.sequences[mapped_id]) for mapped_id in batch_ids)
916916

917-
seq_mat = np.zeros((len(batch_ids), max_length), dtype=np.int)
917+
seq_mat = np.zeros((len(batch_ids), max_length), dtype='int')
918918
for i, mapped_id in enumerate(batch_ids):
919919
idx_seq = self.sequences[mapped_id][:max_length]
920920
for j, idx in enumerate(idx_seq):

cornac/eval_methods/base_method.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def rating_eval(model, metrics, test_set, user_based=False, verbose=False):
7979
miniters=100,
8080
total=len(u_indices),
8181
),
82-
dtype=np.float,
82+
dtype='float',
8383
)
8484

8585
gt_mat = test_set.csr_matrix
@@ -177,7 +177,7 @@ def pos_items(csr_row):
177177
if len(test_pos_items) == 0:
178178
continue
179179

180-
u_gt_pos = np.zeros(test_set.num_items, dtype=np.int)
180+
u_gt_pos = np.zeros(test_set.num_items, dtype='int')
181181
u_gt_pos[test_pos_items] = 1
182182

183183
val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx))
@@ -187,7 +187,7 @@ def pos_items(csr_row):
187187
else pos_items(train_mat.getrow(user_idx))
188188
)
189189

190-
u_gt_neg = np.ones(test_set.num_items, dtype=np.int)
190+
u_gt_neg = np.ones(test_set.num_items, dtype='int')
191191
u_gt_neg[test_pos_items + val_pos_items + train_pos_items] = 0
192192

193193
item_indices = None if exclude_unknowns else np.arange(test_set.num_items)

cornac/eval_methods/propensity_stratified_evaluation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def pos_items(csr_row):
8787
if len(test_pos_items) == 0:
8888
continue
8989

90-
u_gt_pos = np.zeros(test_set.num_items, dtype=np.float)
90+
u_gt_pos = np.zeros(test_set.num_items, dtype='float')
9191
u_gt_pos[test_pos_items] = 1
9292

9393
val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx))
@@ -97,7 +97,7 @@ def pos_items(csr_row):
9797
else pos_items(train_mat.getrow(user_idx))
9898
)
9999

100-
u_gt_neg = np.ones(test_set.num_items, dtype=np.int)
100+
u_gt_neg = np.ones(test_set.num_items, dtype='int')
101101
u_gt_neg[test_pos_items + val_pos_items + train_pos_items] = 0
102102

103103
item_indices = None if exclude_unknowns else np.arange(test_set.num_items)
@@ -256,7 +256,7 @@ def _estimate_propensities(self):
256256
item_freq[i] += 1
257257

258258
# fit the exponential param
259-
data = np.array([e for e in item_freq.values()], dtype=np.float)
259+
data = np.array([e for e in item_freq.values()], dtype='float')
260260
results = powerlaw.Fit(data, discrete=True, fit_method="Likelihood")
261261
alpha = results.power_law.alpha
262262
fmin = results.power_law.xmin
@@ -277,7 +277,7 @@ def _build_stratified_dataset(self, test_data):
277277

278278
# match the corresponding propensity score for each feedback
279279
test_props = np.array(
280-
[self.props[i] for u, i, r in test_data], dtype=np.float64
280+
[self.props[i] for u, i, r in test_data], dtype='float'
281281
)
282282

283283
# stratify

cornac/metrics/ranking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,8 @@ def compute(self, pd_scores, gt_pos, gt_neg=None, **kwargs):
437437
if gt_neg is None:
438438
gt_neg = np.logical_not(gt_pos)
439439

440-
pos_scores = pd_scores[gt_pos.astype(np.bool)]
441-
neg_scores = pd_scores[gt_neg.astype(np.bool)]
440+
pos_scores = pd_scores[gt_pos.astype('bool')]
441+
neg_scores = pd_scores[gt_neg.astype('bool')]
442442
ui_scores = np.repeat(pos_scores, len(neg_scores))
443443
uj_scores = np.tile(neg_scores, len(pos_scores))
444444

@@ -476,7 +476,7 @@ def compute(self, pd_scores, gt_pos, **kwargs):
476476
AP score.
477477
478478
"""
479-
relevant = gt_pos.astype(np.bool)
479+
relevant = gt_pos.astype('bool')
480480
rank = rankdata(-pd_scores, "max")[relevant]
481481
L = rankdata(-pd_scores[relevant], "max")
482482
ans = (L / rank).mean()

cornac/models/cdl/recom_cdl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def fit(self, train_set, val_set=None):
181181
def _fit_cdl(self):
182182
import tensorflow.compat.v1 as tf
183183
from .cdl import Model
184+
185+
tf.disable_eager_execution()
184186

185187
R = self.train_set.csc_matrix # csc for efficient slicing over items
186188
n_users, n_items = self.train_set.num_users, self.train_set.num_items

cornac/models/cdl/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
tensorflow==1.15.2
1+
tensorflow==2.12.0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ============================================================================
1515
"""Collaborative Deep Ranking model"""
1616

17-
import tensorflow as tf
17+
import tensorflow.compat.v1 as tf
1818

1919
from ..cdl.cdl import sdae
2020

cornac/models/cdr/recom_cdr.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,10 @@ def fit(self, train_set, val_set=None):
168168
return self
169169

170170
def _fit_cdr(self):
171-
import tensorflow as tf
172-
from .model import Model
171+
import tensorflow.compat.v1 as tf
172+
from .cdr import Model
173+
174+
tf.disable_eager_execution()
173175

174176
n_users = self.train_set.num_users
175177
n_items = self.train_set.num_items

0 commit comments

Comments
 (0)