forked from alibaba/EasyRec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdssm_senet.py
More file actions
156 lines (130 loc) · 5.8 KB
/
dssm_senet.py
File metadata and controls
156 lines (130 loc) · 5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from easy_rec.python.layers import dnn
from easy_rec.python.layers import senet
from easy_rec.python.model.dssm import DSSM
from easy_rec.python.model.match_model import MatchModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity
from easy_rec.python.utils.proto_util import copy_obj
from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config # NOQA
# TF2 compatibility shim: fall back to the TF1 API surface so the
# tf.layers / tf.losses / tf.get_variable calls below keep working
# when TensorFlow 2.x is installed.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
losses = tf.losses
class DSSM_SENet(DSSM):
  """Two-tower DSSM matching model with SENet feature re-weighting.

  Each tower (user / item) passes its per-feature embeddings through a
  SENet layer, batch-normalizes the concatenated output, runs it through
  a DNN and a final linear projection, then scores the pair with the
  configured similarity function.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Initialize the model and build the per-tower feature inputs.

    Args:
      model_config: model config proto whose `model` oneof is `dssm_senet`.
      feature_configs: feature config protos.
      features: input feature dict.
      labels: label tensors; None at inference time.
      is_training: whether the graph is built for training.
    """
    # NOTE: deliberately bypasses DSSM.__init__ and calls
    # MatchModel.__init__ directly — the tower inputs here are built with
    # is_combine=False to keep per-feature embeddings for SENet.
    MatchModel.__init__(self, model_config, feature_configs, features, labels,
                        is_training)
    assert self._model_config.WhichOneof('model') == 'dssm_senet', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.dssm_senet
    assert isinstance(self._model_config, DSSM_SENet_Config)

    # copy_obj so that any modification will not affect original config
    self.user_tower = copy_obj(self._model_config.user_tower)
    self.user_seq_features, self.user_plain_features, self.user_feature_list = self._input_layer(
        self._feature_dict, 'user', is_combine=False)
    self.user_num_fields = len(self.user_feature_list)

    # copy_obj so that any modification will not affect original config
    self.item_tower = copy_obj(self._model_config.item_tower)
    self.item_seq_features, self.item_plain_features, self.item_feature_list = self._input_layer(
        self._feature_dict, 'item', is_combine=False)
    self.item_num_fields = len(self.item_feature_list)

    self._user_tower_emb = None
    self._item_tower_emb = None

  def _build_tower_emb(self, tower, feature_list, num_fields, prefix):
    """Build one tower embedding: SENet -> BN -> DNN -> linear projection.

    Shared by the user and item towers to remove the previously duplicated
    construction code.  All layer / variable-scope names are derived from
    `prefix`, producing exactly the original names ('user_senet',
    'item_dnn/dnn_%d', ...) so existing checkpoints still load.

    Args:
      tower: tower config proto (provides `senet` and `dnn` sub-configs).
      feature_list: list of per-feature embedding tensors for this tower.
      num_fields: number of feature fields, len(feature_list).
      prefix: 'user' or 'item'; used for layer and variable names.

    Returns:
      The tower embedding tensor from the final dense layer.
    """
    senet_layer = senet.SENet(
        num_fields=num_fields,
        num_squeeze_group=tower.senet.num_squeeze_group,
        reduction_ratio=tower.senet.reduction_ratio,
        # 'acitvation' is the field's actual (misspelled) name in the proto.
        excitation_acitvation=tower.senet.excitation_acitvation,
        l2_reg=self._l2_reg,
        name='%s_senet' % prefix)
    senet_output = tf.concat(senet_layer(feature_list), axis=-1)
    senet_output = tf.layers.batch_normalization(
        senet_output,
        training=self._is_training,
        trainable=True,
        name='%s_senet_bn' % prefix)

    num_dnn_layer = len(tower.dnn.hidden_units)
    # The last hidden size is popped off and realized as a plain linear
    # tf.layers.dense (no activation); pop() mutates only the copied
    # tower config made in __init__.
    last_hidden = tower.dnn.hidden_units.pop()
    tower_dnn = dnn.DNN(tower.dnn, self._l2_reg, '%s_dnn' % prefix,
                        self._is_training)
    tower_emb = tower_dnn(senet_output)
    return tf.layers.dense(
        inputs=tower_emb,
        units=last_hidden,
        kernel_regularizer=self._l2_reg,
        name='%s_dnn/dnn_%d' % (prefix, num_dnn_layer - 1))

  def build_predict_graph(self):
    """Build the user/item similarity prediction graph.

    Returns:
      self._prediction_dict holding 'logits'/'probs' (classification or
      softmax cross-entropy), 'y' (otherwise), plus the raw and
      comma-joined string tower embeddings.
    """
    user_tower_emb = self._build_tower_emb(self.user_tower,
                                           self.user_feature_list,
                                           self.user_num_fields, 'user')
    item_tower_emb = self._build_tower_emb(self.item_tower,
                                           self.item_feature_list,
                                           self.item_num_fields, 'item')

    if self._model_config.simi_func == Similarity.COSINE:
      # Cosine similarity: L2-normalize both embeddings and apply the
      # configured temperature to sharpen/soften the logits.
      user_tower_emb = self.norm(user_tower_emb)
      item_tower_emb = self.norm(item_tower_emb)
      temperature = self._model_config.temperature
    else:
      temperature = 1.0

    user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
    if self._model_config.scale_simi:
      # Learnable affine rescaling of the similarity; tf.abs keeps the
      # slope non-negative so higher similarity maps to a higher score.
      sim_w = tf.get_variable(
          'sim_w',
          dtype=tf.float32,
          shape=(1),
          initializer=tf.ones_initializer())
      sim_b = tf.get_variable(
          'sim_b',
          dtype=tf.float32,
          shape=(1),
          initializer=tf.zeros_initializer())
      y_pred = user_item_sim * tf.abs(sim_w) + sim_b
    else:
      y_pred = user_item_sim

    if self._is_point_wise:
      y_pred = tf.reshape(y_pred, [-1])

    if self._loss_type == LossType.CLASSIFICATION:
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      y_pred = self._mask_in_batch(y_pred)
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
    else:
      self._prediction_dict['y'] = y_pred

    self._prediction_dict['user_tower_emb'] = user_tower_emb
    self._prediction_dict['item_tower_emb'] = item_tower_emb
    # Comma-joined string forms of the embeddings, presumably for
    # export/serving output — verify against the exporter.
    self._prediction_dict['user_emb'] = tf.reduce_join(
        tf.as_string(user_tower_emb), axis=-1, separator=',')
    self._prediction_dict['item_emb'] = tf.reduce_join(
        tf.as_string(item_tower_emb), axis=-1, separator=',')
    return self._prediction_dict

  def build_output_dict(self):
    """Delegate directly to MatchModel.build_output_dict."""
    output_dict = MatchModel.build_output_dict(self)
    return output_dict