Skip to content

Commit 3ef1678

Browse files
authored
Fixed save and load NARRE model (#517)
Fixed save and load NARRE model
1 parent 042321a commit 3ef1678

File tree

2 files changed

+131
-106
lines changed

2 files changed

+131
-106
lines changed

cornac/models/narre/narre.py

Lines changed: 92 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
import tensorflow as tf
1818
from tensorflow import keras
1919
from tensorflow.keras import layers, initializers, Input
20+
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
2021

2122
from ...utils import get_rng
2223
from ...utils.init_utils import uniform
2324

2425

2526
class TextProcessor(keras.Model):
26-
def __init__(self, max_text_length, filters=64, kernel_sizes=[3], dropout_rate=0.5, name=''):
27-
super(TextProcessor, self).__init__(name=name)
27+
def __init__(self, max_text_length, filters=64, kernel_sizes=[3], dropout_rate=0.5, name='', **kwargs):
28+
super(TextProcessor, self).__init__(name=name, **kwargs)
2829
self.max_text_length = max_text_length
2930
self.filters = filters
3031
self.kernel_sizes = kernel_sizes
@@ -51,7 +52,6 @@ def call(self, inputs, training=False):
5152

5253

5354
def get_data(batch_ids, train_set, max_text_length, by='user', max_num_review=None):
54-
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
5555
batch_reviews, batch_id_reviews, batch_num_reviews = [], [], []
5656
review_group = train_set.review_text.user_review if by == 'user' else train_set.review_text.item_review
5757
for idx in batch_ids:
@@ -65,8 +65,8 @@ def get_data(batch_ids, train_set, max_text_length, by='user', max_num_review=No
6565
reviews = train_set.review_text.batch_seq(review_ids, max_length=max_text_length)
6666
batch_reviews.append(reviews)
6767
batch_num_reviews.append(len(reviews))
68-
batch_reviews = pad_sequences(batch_reviews, padding="post")
69-
batch_id_reviews = pad_sequences(batch_id_reviews, padding="post")
68+
batch_reviews = pad_sequences(batch_reviews, maxlen=max_num_review, padding="post")
69+
batch_id_reviews = pad_sequences(batch_id_reviews, maxlen=max_num_review, padding="post")
7070
batch_num_reviews = np.array(batch_num_reviews)
7171
return batch_reviews, batch_id_reviews, batch_num_reviews
7272

@@ -80,13 +80,69 @@ def __init__(self, init_value=0.0, name="global_bias"):
8080
def build(self, input_shape):
8181
self.global_bias = self.add_weight(shape=1,
8282
initializer=tf.keras.initializers.Constant(self.init_value),
83-
trainable=True)
83+
trainable=True, name="add_weight")
8484

8585
def call(self, inputs):
8686
return inputs + self.global_bias
8787

88-
class Model:
89-
def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding_size=100, id_embedding_size=32, attention_size=16, kernel_sizes=[3], n_filters=64, dropout_rate=0.5, max_text_length=50, pretrained_word_embeddings=None, verbose=False, seed=None):
88+
class Model(keras.Model):
    """NARRE rating-prediction network (Chen et al., "Neural Attentional
    Rating Regression with Review-level Explanations", WWW 2018).

    Predicts a rating from user/item ids plus attention-pooled CNN
    encodings of their reviews. Implemented as a `keras.Model` subclass
    (rather than the functional API) so that sub-layers are tracked and
    the model can be saved/loaded as a unit.

    Parameters
    ----------
    n_users, n_items : int
        Numbers of users and items (embedding table sizes).
    n_vocab : int
        Vocabulary size for the review-word embeddings.
    embedding_matrix :
        Keras initializer for the word embeddings (e.g.
        `initializers.Constant(pretrained_vectors)`).
    global_mean : float
        Initial value of the global rating bias.
    n_factors, embedding_size, id_embedding_size, attention_size : int
        Latent sizes of the review factors, word embeddings, id
        embeddings, and attention MLP hidden layer, respectively.
    kernel_sizes : list of int
        Convolution widths of the CNN text encoders.
    n_filters : int
        Number of CNN filters per kernel size.
    dropout_rate : float
        Dropout applied to the pooled review representations.
    max_text_length : int
        Maximum review length (in tokens) fed to the text encoders.
    """

    def __init__(self, n_users, n_items, n_vocab, embedding_matrix, global_mean,
                 n_factors=32, embedding_size=100, id_embedding_size=32,
                 attention_size=16, kernel_sizes=[3], n_filters=64,
                 dropout_rate=0.5, max_text_length=50):
        super().__init__()
        # Word embeddings for review text; mask_zero lets padding be ignored.
        self.l_user_review_embedding = layers.Embedding(n_vocab, embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_user_review_embedding")
        self.l_item_review_embedding = layers.Embedding(n_vocab, embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_item_review_embedding")
        # Id embeddings concatenated with review encodings inside the
        # attention scores (item ids on the user side and vice versa).
        self.l_user_iid_embedding = layers.Embedding(n_items, id_embedding_size, embeddings_initializer="uniform", name="user_iid_embedding")
        self.l_item_uid_embedding = layers.Embedding(n_users, id_embedding_size, embeddings_initializer="uniform", name="item_uid_embedding")
        # Plain user/item id embeddings used in the final interaction.
        self.l_user_embedding = layers.Embedding(n_users, id_embedding_size, embeddings_initializer="uniform", name="user_embedding")
        self.l_item_embedding = layers.Embedding(n_items, id_embedding_size, embeddings_initializer="uniform", name="item_embedding")
        self.user_bias = layers.Embedding(n_users, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="user_bias")
        self.item_bias = layers.Embedding(n_items, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="item_bias")
        # CNN encoders turning each review into a fixed-size vector.
        self.user_text_processor = TextProcessor(max_text_length, filters=n_filters, kernel_sizes=kernel_sizes, dropout_rate=dropout_rate, name='user_text_processor')
        self.item_text_processor = TextProcessor(max_text_length, filters=n_filters, kernel_sizes=kernel_sizes, dropout_rate=dropout_rate, name='item_text_processor')
        # Two-layer MLPs producing unnormalized attention logits per review.
        self.a_user = keras.models.Sequential([
            layers.Dense(attention_size, activation="relu", use_bias=True),
            layers.Dense(1, activation=None, use_bias=True)
        ])
        self.user_attention = layers.Softmax(axis=1, name="user_attention")
        self.a_item = keras.models.Sequential([
            layers.Dense(attention_size, activation="relu", use_bias=True),
            layers.Dense(1, activation=None, use_bias=True)
        ])
        self.item_attention = layers.Softmax(axis=1, name="item_attention")
        self.user_Oi_dropout = layers.Dropout(rate=dropout_rate, name="user_Oi")
        self.Xu = layers.Dense(n_factors, use_bias=True, name="Xu")
        self.item_Oi_dropout = layers.Dropout(rate=dropout_rate, name="item_Oi")
        self.Yi = layers.Dense(n_factors, use_bias=True, name="Yi")

        self.W1 = layers.Dense(1, activation=None, use_bias=False, name="W1")
        self.add_global_bias = AddGlobalBias(init_value=global_mean, name="global_bias")

    def call(self, inputs, training=None):
        """Compute predicted ratings for a batch.

        `inputs` is the 8-tuple of tensors listed on the first line below;
        returns the predicted rating tensor `r`.
        """
        i_user_id, i_item_id, i_user_review, i_user_iid_review, i_user_num_reviews, i_item_review, i_item_uid_review, i_item_num_reviews = inputs
        # --- user side: encode reviews, attend over them, project to Xu ---
        user_review_h = self.user_text_processor(self.l_user_review_embedding(i_user_review), training=training)
        a_user = self.a_user(tf.concat([user_review_h, self.l_user_iid_embedding(i_user_iid_review)], axis=-1))
        # Mask padded review slots (beyond each user's real review count)
        # so the softmax ignores them.
        a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
        user_attention = self.user_attention(a_user, a_user_masking)
        user_Oi = self.user_Oi_dropout(tf.reduce_sum(tf.multiply(user_attention, user_review_h), 1), training=training)
        Xu = self.Xu(user_Oi)
        # --- item side, symmetric to the user side ---
        item_review_h = self.item_text_processor(self.l_item_review_embedding(i_item_review), training=training)
        a_item = self.a_item(tf.concat([item_review_h, self.l_item_uid_embedding(i_item_uid_review)], axis=-1))
        a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
        item_attention = self.item_attention(a_item, a_item_masking)
        item_Oi = self.item_Oi_dropout(tf.reduce_sum(tf.multiply(item_attention, item_review_h), 1), training=training)
        Yi = self.Yi(item_Oi)
        # Element-wise interaction of (id embedding + review factor) for
        # user and item, then linear scoring plus user/item/global biases.
        h0 = tf.multiply(tf.add(self.l_user_embedding(i_user_id), Xu), tf.add(self.l_item_embedding(i_item_id), Yi))
        r = self.add_global_bias(
            tf.add_n([
                self.W1(h0),
                self.user_bias(i_user_id),
                self.item_bias(i_item_id)
            ])
        )
        return r
143+
144+
class NARREModel:
145+
def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding_size=100, id_embedding_size=32, attention_size=16, kernel_sizes=[3], n_filters=64, dropout_rate=0.5, max_text_length=50, max_num_review=32, pretrained_word_embeddings=None, verbose=False, seed=None):
90146
self.n_users = n_users
91147
self.n_items = n_items
92148
self.n_vocab = vocab.size
@@ -99,6 +155,7 @@ def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding
99155
self.n_filters = n_filters
100156
self.dropout_rate = dropout_rate
101157
self.max_text_length = max_text_length
158+
self.max_num_review = max_num_review
102159
self.verbose = verbose
103160
if seed is not None:
104161
self.rng = get_rng(seed)
@@ -118,88 +175,39 @@ def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding
118175
print("Number of OOV words: %d" % oov_count)
119176

120177
embedding_matrix = initializers.Constant(embedding_matrix)
121-
i_user_id = Input(shape=(1,), dtype="int32", name="input_user_id")
122-
i_item_id = Input(shape=(1,), dtype="int32", name="input_item_id")
123-
i_user_review = Input(shape=(None, self.max_text_length), dtype="int32", name="input_user_review")
124-
i_item_review = Input(shape=(None, self.max_text_length), dtype="int32", name="input_item_review")
125-
i_user_iid_review = Input(shape=(None,), dtype="int32", name="input_user_iid_review")
126-
i_item_uid_review = Input(shape=(None,), dtype="int32", name="input_item_uid_review")
127-
i_user_num_reviews = Input(shape=(1,), dtype="int32", name="input_user_number_of_review")
128-
i_item_num_reviews = Input(shape=(1,), dtype="int32", name="input_item_number_of_review")
129-
130-
l_user_review_embedding = layers.Embedding(self.n_vocab, self.embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_user_review_embedding")
131-
l_item_review_embedding = layers.Embedding(self.n_vocab, self.embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_item_review_embedding")
132-
l_user_iid_embedding = layers.Embedding(self.n_items, self.id_embedding_size, embeddings_initializer="uniform", name="user_iid_embedding")
133-
l_item_uid_embedding = layers.Embedding(self.n_users, self.id_embedding_size, embeddings_initializer="uniform", name="item_uid_embedding")
134-
l_user_embedding = layers.Embedding(self.n_users, self.id_embedding_size, embeddings_initializer="uniform", name="user_embedding")
135-
l_item_embedding = layers.Embedding(self.n_items, self.id_embedding_size, embeddings_initializer="uniform", name="item_embedding")
136-
user_bias = layers.Embedding(self.n_users, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="user_bias")
137-
item_bias = layers.Embedding(self.n_items, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="item_bias")
138-
139-
user_text_processor = TextProcessor(self.max_text_length, filters=self.n_filters, kernel_sizes=self.kernel_sizes, dropout_rate=self.dropout_rate, name='user_text_processor')
140-
item_text_processor = TextProcessor(self.max_text_length, filters=self.n_filters, kernel_sizes=self.kernel_sizes, dropout_rate=self.dropout_rate, name='item_text_processor')
141-
142-
user_review_h = user_text_processor(l_user_review_embedding(i_user_review), training=True)
143-
item_review_h = item_text_processor(l_item_review_embedding(i_item_review), training=True)
144-
a_user = layers.Dense(1, activation=None, use_bias=True)(
145-
layers.Dense(self.attention_size, activation="relu", use_bias=True)(
146-
tf.concat([user_review_h, l_user_iid_embedding(i_user_iid_review)], axis=-1)
147-
)
148-
)
149-
a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
150-
user_attention = layers.Softmax(axis=1, name="user_attention")(a_user, a_user_masking)
151-
a_item = layers.Dense(1, activation=None, use_bias=True)(
152-
layers.Dense(self.attention_size, activation="relu", use_bias=True)(
153-
tf.concat([item_review_h, l_item_uid_embedding(i_item_uid_review)], axis=-1)
154-
)
155-
)
156-
a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
157-
item_attention = layers.Softmax(axis=1, name="item_attention")(a_item, a_item_masking)
158-
159-
Xu = layers.Dense(self.n_factors, use_bias=True, name="Xu")(
160-
layers.Dropout(rate=self.dropout_rate, name="user_Oi")(
161-
tf.reduce_sum(layers.Multiply()([user_attention, user_review_h]), 1)
162-
)
163-
)
164-
Yi = layers.Dense(self.n_factors, use_bias=True, name="Yi")(
165-
layers.Dropout(rate=self.dropout_rate, name="item_Oi")(
166-
tf.reduce_sum(layers.Multiply()([item_attention, item_review_h]), 1)
167-
)
178+
self.graph = Model(
179+
self.n_users, self.n_items, self.n_vocab, embedding_matrix, self.global_mean,
180+
self.n_factors, self.embedding_size, self.id_embedding_size, self.attention_size,
181+
self.kernel_sizes, self.n_filters, self.dropout_rate, self.max_text_length
168182
)
169183

170-
h0 = layers.Multiply(name="h0")([
171-
layers.Add()([l_user_embedding(i_user_id), Xu]), layers.Add()([l_item_embedding(i_item_id), Yi])
172-
])
173-
174-
W1 = layers.Dense(1, activation=None, use_bias=False, name="W1")
175-
add_global_bias = AddGlobalBias(init_value=self.global_mean, name="global_bias")
176-
r = layers.Add(name="prediction")([
177-
W1(h0),
178-
user_bias(i_user_id),
179-
item_bias(i_item_id)
180-
])
181-
r = add_global_bias(r)
182-
self.graph = keras.Model(inputs=[i_user_id, i_item_id, i_user_review, i_user_iid_review, i_user_num_reviews, i_item_review, i_item_uid_review, i_item_num_reviews], outputs=r)
183-
if self.verbose:
184-
self.graph.summary()
185-
186-
def get_weights(self, train_set, batch_size=64, max_num_review=None):
187-
user_attention_review_pooling = keras.Model(inputs=[self.graph.get_layer('input_user_review').input, self.graph.get_layer('input_user_iid_review').input, self.graph.get_layer('input_user_number_of_review').input], outputs=self.graph.get_layer('Xu').output)
188-
item_attention_review_pooling = keras.Model(inputs=[self.graph.get_layer('input_item_review').input, self.graph.get_layer('input_item_uid_review').input, self.graph.get_layer('input_item_number_of_review').input], outputs=self.graph.get_layer('Yi').output)
184+
def get_weights(self, train_set, batch_size=64):
    """Extract fitted parameters for scoring outside the TF graph.

    Runs the trained sub-layers of ``self.graph`` manually (mirroring
    ``Model.call`` but with ``training=False``, i.e. dropout disabled)
    to compute the attention-pooled review factors for every user and
    item, then pulls the remaining weight matrices out of the layers.

    Parameters
    ----------
    train_set :
        Training set providing ``user_iter``/``item_iter`` batch
        iterators and the review text (consumed via ``get_data``).
    batch_size : int, optional
        Number of users/items processed per forward pass.

    Returns
    -------
    tuple
        ``(X, Y, W1, user_embedding, item_embedding, bu, bi, mu)`` —
        per-user factors ``X``, per-item factors ``Y``, the scoring
        weights ``W1``, the id-embedding tables, the user/item biases,
        and the scalar global bias ``mu``.
    """
    X = np.zeros((self.n_users, self.n_factors))
    Y = np.zeros((self.n_items, self.n_factors))
    # Per-user attention pooling: Xu = Dense(sum_j attention_j * review_h_j).
    for batch_users in train_set.user_iter(batch_size):
        i_user_review, i_user_iid_review, i_user_num_reviews = get_data(batch_users, train_set, self.max_text_length, by='user', max_num_review=self.max_num_review)
        user_review_embedding = self.graph.l_user_review_embedding(i_user_review)
        user_review_h = self.graph.user_text_processor(user_review_embedding, training=False)
        a_user = self.graph.a_user(tf.concat([user_review_h, self.graph.l_user_iid_embedding(i_user_iid_review)], axis=-1))
        # Mask review slots beyond each user's real review count.
        a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
        user_attention = self.graph.user_attention(a_user, a_user_masking)
        user_Oi = tf.reduce_sum(tf.multiply(user_attention, user_review_h), 1)
        Xu = self.graph.Xu(user_Oi)
        X[batch_users] = Xu.numpy()
    # Per-item attention pooling, symmetric to the user side.
    for batch_items in train_set.item_iter(batch_size):
        i_item_review, i_item_uid_review, i_item_num_reviews = get_data(batch_items, train_set, self.max_text_length, by='item', max_num_review=self.max_num_review)
        item_review_embedding = self.graph.l_item_review_embedding(i_item_review)
        item_review_h = self.graph.item_text_processor(item_review_embedding, training=False)
        a_item = self.graph.a_item(tf.concat([item_review_h, self.graph.l_item_uid_embedding(i_item_uid_review)], axis=-1))
        a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
        item_attention = self.graph.item_attention(a_item, a_item_masking)
        item_Oi = tf.reduce_sum(tf.multiply(item_attention, item_review_h), 1)
        Yi = self.graph.Yi(item_Oi)
        Y[batch_items] = Yi.numpy()
    # Remaining parameters come straight out of the layers as numpy arrays
    # (attribute access on the subclassed model, not get_layer by name).
    W1 = self.graph.W1.get_weights()[0]
    user_embedding = self.graph.l_user_embedding.get_weights()[0]
    item_embedding = self.graph.l_item_embedding.get_weights()[0]
    bu = self.graph.user_bias.get_weights()[0]
    bi = self.graph.item_bias.get_weights()[0]
    mu = self.graph.add_global_bias.get_weights()[0][0]
    return X, Y, W1, user_embedding, item_embedding, bu, bi, mu

0 commit comments

Comments
 (0)