13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================ |
15 | 15 |
16 | | -import warnings |
17 | 16 |
18 | | -# disable annoying tensorflow deprecated API warnings |
19 | | -warnings.filterwarnings("ignore", category=UserWarning) |
20 | | - |
21 | | -import tensorflow.compat.v1 as tf |
22 | | - |
23 | | -tf.logging.set_verbosity(tf.logging.ERROR) |
24 | | -tf.disable_v2_behavior() |
| 17 | +import tensorflow as tf |
25 | 18 |
26 | 19 |
27 | 20 | act_functions = { |
35 | 28 | } |
36 | 29 |
37 | 30 |
38 | | -def loss_fn(labels, logits): |
39 | | - cross_entropy = tf.reduce_mean( |
40 | | - tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) |
41 | | - ) |
42 | | - reg_loss = tf.losses.get_regularization_loss() |
43 | | - return cross_entropy + reg_loss |
44 | | - |
45 | | - |
46 | | -def train_fn(loss, learning_rate, learner): |
| 31 | +def get_optimizer(learning_rate, learner): |
47 | 32 | if learner.lower() == "adagrad": |
48 | | - opt = tf.train.AdagradOptimizer(learning_rate=learning_rate, name="optimizer") |
| 33 | + return tf.keras.optimizers.Adagrad(learning_rate=learning_rate) |
49 | 34 | elif learner.lower() == "rmsprop": |
50 | | - opt = tf.train.RMSPropOptimizer(learning_rate=learning_rate, name="optimizer") |
| 35 | + return tf.keras.optimizers.RMSprop(learning_rate=learning_rate) |
51 | 36 | elif learner.lower() == "adam": |
52 | | - opt = tf.train.AdamOptimizer(learning_rate=learning_rate, name="optimizer") |
| 37 | + return tf.keras.optimizers.Adam(learning_rate=learning_rate) |
53 | 38 | else: |
54 | | - opt = tf.train.GradientDescentOptimizer( |
55 | | - learning_rate=learning_rate, name="optimizer" |
56 | | - ) |
57 | | - |
58 | | - return opt.minimize(loss) |
59 | | - |
60 | | - |
61 | | -def emb( |
62 | | - uid, iid, num_users, num_items, emb_size, reg_user, reg_item, seed=None, scope="emb" |
63 | | -): |
64 | | - with tf.variable_scope(scope): |
65 | | - user_emb = tf.get_variable( |
66 | | - "user_emb", |
67 | | - shape=[num_users, emb_size], |
68 | | - dtype=tf.float32, |
69 | | - initializer=tf.random_normal_initializer(stddev=0.01, seed=seed), |
70 | | - regularizer=tf.keras.regularizers.L2(reg_user), |
| 39 | + return tf.keras.optimizers.SGD(learning_rate=learning_rate) |
| 40 | + |
| 41 | + |
| 42 | +class GMFLayer(tf.keras.layers.Layer): |
| 43 | + def __init__(self, num_users, num_items, emb_size, reg_user, reg_item, seed=None, **kwargs): |
| 44 | + super(GMFLayer, self).__init__(**kwargs) |
| 45 | + self.num_users = num_users |
| 46 | + self.num_items = num_items |
| 47 | + self.emb_size = emb_size |
| 48 | + self.reg_user = reg_user |
| 49 | + self.reg_item = reg_item |
| 50 | + self.seed = seed |
| 51 | + |
| 52 | + # Initialize embeddings |
| 53 | + self.user_embedding = tf.keras.layers.Embedding( |
| 54 | + num_users, |
| 55 | + emb_size, |
| 56 | + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), |
| 57 | + embeddings_regularizer=tf.keras.regularizers.L2(reg_user), |
| 58 | + name="user_embedding" |
71 | 59 | ) |
72 | | - item_emb = tf.get_variable( |
73 | | - "item_emb", |
74 | | - shape=[num_items, emb_size], |
75 | | - dtype=tf.float32, |
76 | | - initializer=tf.random_normal_initializer(stddev=0.01, seed=seed), |
77 | | - regularizer=tf.keras.regularizers.L2(reg_item), |
78 | | - ) |
79 | | - |
80 | | - return tf.nn.embedding_lookup(user_emb, uid), tf.nn.embedding_lookup(item_emb, iid) |
81 | | - |
82 | | - |
83 | | -def gmf(uid, iid, num_users, num_items, emb_size, reg_user, reg_item, seed=None): |
84 | | - with tf.variable_scope("GMF") as scope: |
85 | | - user_emb, item_emb = emb( |
86 | | - uid=uid, |
87 | | - iid=iid, |
88 | | - num_users=num_users, |
89 | | - num_items=num_items, |
90 | | - emb_size=emb_size, |
91 | | - reg_user=reg_user, |
92 | | - reg_item=reg_item, |
93 | | - seed=seed, |
94 | | - scope=scope, |
| 60 | + |
| 61 | + self.item_embedding = tf.keras.layers.Embedding( |
| 62 | + num_items, |
| 63 | + emb_size, |
| 64 | + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), |
| 65 | + embeddings_regularizer=tf.keras.regularizers.L2(reg_item), |
| 66 | + name="item_embedding" |
95 | 67 | ) |
| 68 | + |
| 69 | + def call(self, inputs): |
| 70 | + user_ids, item_ids = inputs |
| 71 | + user_emb = self.user_embedding(user_ids) |
| 72 | + item_emb = self.item_embedding(item_ids) |
96 | 73 | return tf.multiply(user_emb, item_emb) |
97 | 74 |
98 | 75 |
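The rewritten GMFLayer keeps the old gmf() contract: look up user and item embeddings and return their element-wise product. A minimal usage sketch of the pieces added so far (all sizes and IDs below are illustrative assumptions, not values from this repo):

import tensorflow as tf

# Hypothetical dimensions, chosen only for the example.
gmf = GMFLayer(num_users=100, num_items=200, emb_size=8,
               reg_user=1e-6, reg_item=1e-6, seed=42)
optimizer = get_optimizer(learning_rate=0.001, learner="adam")

user_ids = tf.constant([0, 1, 2])
item_ids = tf.constant([5, 6, 7])
out = gmf((user_ids, item_ids))  # element-wise product, shape (3, 8)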
99 | | -def mlp(uid, iid, num_users, num_items, layers, reg_layers, act_fn, seed=None): |
100 | | - with tf.variable_scope("MLP") as scope: |
101 | | - user_emb, item_emb = emb( |
102 | | - uid=uid, |
103 | | - iid=iid, |
104 | | - num_users=num_users, |
105 | | - num_items=num_items, |
106 | | - emb_size=int(layers[0] / 2), |
107 | | - reg_user=reg_layers[0], |
108 | | - reg_item=reg_layers[0], |
109 | | - seed=seed, |
110 | | - scope=scope, |
| 76 | +class MLPLayer(tf.keras.layers.Layer): |
| 77 | + def __init__(self, num_users, num_items, layers, reg_layers, act_fn, seed=None, **kwargs): |
| 78 | + super(MLPLayer, self).__init__(**kwargs) |
| 79 | + self.num_users = num_users |
| 80 | + self.num_items = num_items |
| 81 | + self.layers = layers |
| 82 | + self.reg_layers = reg_layers |
| 83 | + self.act_fn = act_fn |
| 84 | + self.seed = seed |
| 85 | + |
| 86 | + # Initialize embeddings |
| 87 | + self.user_embedding = tf.keras.layers.Embedding( |
| 88 | + num_users, |
| 89 | + int(layers[0] / 2), |
| 90 | + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), |
| 91 | + embeddings_regularizer=tf.keras.regularizers.L2(reg_layers[0]), |
| 92 | + name="user_embedding" |
111 | 93 | ) |
112 | | - interaction = tf.concat([user_emb, item_emb], axis=-1) |
113 | | - for i, layer in enumerate(layers[1:]): |
114 | | - interaction = tf.layers.dense( |
115 | | - interaction, |
116 | | - units=layer, |
117 | | - name="layer{}".format(i + 1), |
118 | | - activation=act_functions.get(act_fn, tf.nn.relu), |
119 | | - kernel_initializer=tf.initializers.lecun_uniform(seed), |
120 | | - kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]), |
| 94 | + |
| 95 | + self.item_embedding = tf.keras.layers.Embedding( |
| 96 | + num_items, |
| 97 | + int(layers[0] / 2), |
| 98 | + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), |
| 99 | + embeddings_regularizer=tf.keras.regularizers.L2(reg_layers[0]), |
| 100 | + name="item_embedding" |
| 101 | + ) |
| 102 | + |
| 103 | + # Define dense layers |
| 104 | + self.dense_layers = [] |
| 105 | + for i, layer_size in enumerate(layers[1:]): |
| 106 | + self.dense_layers.append( |
| 107 | + tf.keras.layers.Dense( |
| 108 | + layer_size, |
| 109 | + activation=act_functions.get(act_fn, tf.nn.relu), |
| 110 | + kernel_initializer=tf.keras.initializers.LecunUniform(seed=seed), |
| 111 | + kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]), |
| 112 | + name=f"layer{i+1}" |
| 113 | + ) |
121 | 114 | ) |
| 115 | + |
| 116 | + def call(self, inputs): |
| 117 | + user_ids, item_ids = inputs |
| 118 | + user_emb = self.user_embedding(user_ids) |
| 119 | + item_emb = self.item_embedding(item_ids) |
| 120 | + interaction = tf.concat([user_emb, item_emb], axis=-1) |
| 121 | + |
| 122 | + for layer in self.dense_layers: |
| 123 | + interaction = layer(interaction) |
| 124 | + |
122 | 125 | return interaction |
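The diff also drops loss_fn(), whose tf.losses.get_regularization_loss() has no direct TF2 counterpart; in Keras, the embedding regularizers registered above surface through each layer's .losses list instead. A hedged sketch of an eager training step under that assumption (the Dense prediction head and all sizes here are invented for illustration, not part of this diff):

import tensorflow as tf

mlp = MLPLayer(num_users=100, num_items=200, layers=[16, 8, 4],
               reg_layers=[1e-6, 1e-6, 1e-6], act_fn="relu", seed=42)
head = tf.keras.layers.Dense(1)  # hypothetical prediction head, not in this diff
optimizer = get_optimizer(learning_rate=0.001, learner="adam")

user_ids = tf.constant([0, 1, 2])
item_ids = tf.constant([5, 6, 7])
labels = tf.constant([[1.0], [0.0], [1.0]])

with tf.GradientTape() as tape:
    logits = head(mlp((user_ids, item_ids)))
    # Same data term the removed loss_fn computed.
    cross_entropy = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    )
    # Regularization now lives on the layers, not a global collection.
    reg_loss = tf.add_n(mlp.losses) if mlp.losses else 0.0
    loss = cross_entropy + reg_loss

variables = mlp.trainable_variables + head.trainable_variables
grads = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(grads, variables))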