
Commit d2ba48b

Merge pull request #483 from thinkall/master
add support to change history length and embedding size for DMR
2 parents: be3b88d + ad31446

File tree

1 file changed: +61 −40 lines

models/rank/dmr/net.py

Lines changed: 61 additions & 40 deletions
@@ -46,6 +46,7 @@ def __init__(self, user_size, cms_segid_size, cms_group_id_size,
         self.pid_size = pid_size
         self.main_embedding_size = main_embedding_size
         self.other_embedding_size = other_embedding_size
+        self.history_length = 50

         self.uid_embeddings_var = paddle.nn.Embedding(
             self.user_size,
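Note that `self.history_length` is introduced as a named attribute but is still fixed at 50 here; the commit replaces the magic number, not the value. A minimal sketch of how it could be supplied as a hyperparameter next to the embedding sizes (hypothetical constructor, not part of this commit):

# Hypothetical sketch, not part of this commit: history_length supplied
# as a hyperparameter alongside the embedding sizes instead of fixed at 50.
class DMRConfigSketch(object):
    def __init__(self,
                 main_embedding_size=32,
                 other_embedding_size=8,
                 history_length=50):
        self.main_embedding_size = main_embedding_size
        self.other_embedding_size = other_embedding_size
        self.history_length = history_length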
@@ -183,38 +184,48 @@ def __init__(self, user_size, cms_segid_size, cms_group_id_size,
                 name="PidSparseFeatFactors",
                 initializer=nn.initializer.Uniform()))

-        self.position_his = paddle.arange(0, 50)
+        self.position_his = paddle.arange(0, self.history_length)
         self.position_embeddings_var = paddle.nn.Embedding(
-            50,
+            self.history_length,
             self.other_embedding_size,
             # sparse=True,
             weight_attr=paddle.ParamAttr(
                 name="PosSparseFeatFactors",
                 initializer=nn.initializer.Uniform()))

-        self.dm_position_his = paddle.arange(0, 50)
+        self.dm_position_his = paddle.arange(0, self.history_length)
         self.dm_position_embeddings_var = paddle.nn.Embedding(
-            50,
+            self.history_length,
             self.other_embedding_size,
             # sparse=True,
             weight_attr=paddle.ParamAttr(
                 name="DmPosSparseFeatFactors",
                 initializer=nn.initializer.Uniform()))
-        self.query_layer = paddle.nn.Linear(16, 64, name='dm_align')
+        self.query_layer = paddle.nn.Linear(
+            self.other_embedding_size * 2,
+            self.main_embedding_size * 2,
+            name='dm_align')
         self.query_prelu = paddle.nn.PReLU(
-            num_parameters=50, init=0.1, name='dm_prelu')
-        self.att_layer1_layer = paddle.nn.Linear(256, 80, name='dm_att_1')
+            num_parameters=self.history_length, init=0.1, name='dm_prelu')
+        self.att_layer1_layer = paddle.nn.Linear(
+            self.main_embedding_size * 8, 80, name='dm_att_1')
         self.att_layer2_layer = paddle.nn.Linear(80, 40, name='dm_att_2')
         self.att_layer3_layer = paddle.nn.Linear(40, 1, name='dm_att_3')
         self.dnn_layer1_layer = paddle.nn.Linear(
-            64, self.main_embedding_size, name='dm_fcn_1')
+            self.main_embedding_size * 2,
+            self.main_embedding_size,
+            name='dm_fcn_1')
         self.dnn_layer1_prelu = paddle.nn.PReLU(
-            num_parameters=50, init=0.1, name='dm_fcn_1')
+            num_parameters=self.history_length, init=0.1, name='dm_fcn_1')

-        self.query_layer2 = paddle.nn.Linear(80, 64, name='dmr_align')
+        self.query_layer2 = paddle.nn.Linear(
+            (self.other_embedding_size + self.main_embedding_size) * 2,
+            self.main_embedding_size * 2,
+            name='dmr_align')
         self.query_prelu2 = paddle.nn.PReLU(
-            num_parameters=50, init=0.1, name='dmr_prelu')
-        self.att_layer1_layer2 = paddle.nn.Linear(256, 80, name='tg_att_1')
+            num_parameters=self.history_length, init=0.1, name='dmr_prelu')
+        self.att_layer1_layer2 = paddle.nn.Linear(
+            self.main_embedding_size * 8, 80, name='tg_att_1')
         self.att_layer2_layer2 = paddle.nn.Linear(80, 40, name='tg_att_2')
         self.att_layer3_layer2 = paddle.nn.Linear(40, 1, name='tg_att_3')

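The parameterized widths reproduce the old hardcoded literals at the sizes those literals imply (main_embedding_size = 32, other_embedding_size = 8). A quick sanity check, assuming those defaults:

# Sanity check (assumes main=32, other=8, the sizes the old literals imply).
main, other = 32, 8
assert other * 2 == 16               # dm_align input, was Linear(16, ...)
assert main * 2 == 64                # dm_align output, was Linear(..., 64)
assert main * 8 == 256               # dm_att_1 / tg_att_1 input, was 256
assert (other + main) * 2 == 80      # dmr_align input, was Linear(80, ...)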
@@ -224,7 +235,8 @@ def __init__(self, user_size, cms_segid_size, cms_group_id_size,
         def deep_match(item_his_eb, context_his_eb, mask, match_mask,
                        mid_his_batch, item_vectors, item_biases, n_mid):
             query = context_his_eb
-            query = self.query_layer(query)  # [-1, 50, 64]
+            query = self.query_layer(
+                query)  # [-1, self.history_length, self.main_embedding_size*2]
             query = self.query_prelu(query)

             inputs = paddle.concat(
@@ -264,7 +276,8 @@ def deep_match(item_his_eb, context_his_eb, mask, match_mask,
             att_dm_item_his_eb = paddle.matmul(scores_tile,
                                                item_his_eb)  # B, T, E
             dnn_layer1 = self.dnn_layer1_layer(att_dm_item_his_eb)
-            dnn_layer1 = dnn_layer1.reshape([-1, 50, 32])  ##
+            dnn_layer1 = dnn_layer1.reshape(
+                [-1, self.history_length, self.main_embedding_size])  ##
             dnn_layer1 = self.dnn_layer1_prelu(dnn_layer1)

             # target mask
@@ -316,7 +329,7 @@ def dmr_fcn_attention(item_eb,
             att_layer_3 = paddle.reshape(
                 att_layer_3, [-1, 1, paddle.shape(item_his_eb)[1]])  # B,1,T
             scores = att_layer_3
-            scores = scores.reshape([-1, 1, 50])  ##
+            scores = scores.reshape([-1, 1, self.history_length])  ##

             # Mask
             key_masks = paddle.unsqueeze(mask, 1)  # B,1,T
@@ -353,8 +366,14 @@ def dmr_fcn_attention(
         self.dm_item_biases = paddle.zeros(
             shape=[self.cate_size], dtype='float32')

-        self.inp_layer = paddle.nn.BatchNorm(459, momentum=0.99, epsilon=1e-03)
-        self.dnn0_layer = paddle.nn.Linear(459, 512, name='f0')
+        self.inp_length = self.main_embedding_size + (
+            self.other_embedding_size * 8 + self.main_embedding_size * 5 + 1 +
+            self.other_embedding_size + self.main_embedding_size * 2 +
+            self.main_embedding_size * 2 + 1 + 1 + self.main_embedding_size * 2
+        )
+        self.inp_layer = paddle.nn.BatchNorm(
+            self.inp_length, momentum=0.99, epsilon=1e-03)
+        self.dnn0_layer = paddle.nn.Linear(self.inp_length, 512, name='f0')
         self.dnn0_prelu = paddle.nn.PReLU(
             num_parameters=512, init=0.1, name='prelu0')
         self.dnn1_layer = paddle.nn.Linear(512, 256, name='f1')
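`inp_length` spells out how the width of the concatenated input to the first fully connected layer is built from the embedding sizes. Under the same assumed defaults (main = 32, other = 8) it reduces to the old hardcoded 459:

# The symbolic width reduces to the old literal 459 for main=32, other=8:
# 32 + (8*8 + 32*5 + 1 + 8 + 32*2 + 32*2 + 1 + 1 + 32*2)
# = 32 + (64 + 160 + 1 + 8 + 64 + 64 + 1 + 1 + 64) = 32 + 427 = 459
main, other = 32, 8
inp_length = main + (other * 8 + main * 5 + 1 + other + main * 2 +
                     main * 2 + 1 + 1 + main * 2)
assert inp_length == 459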
@@ -371,33 +390,35 @@ def forward(self, inputs_tensor, is_infer=0):
         # input
         inputs = inputs_tensor[0]  # sparse_tensor
         dense_tensor = inputs_tensor[1]
-        self.btag_his = inputs[:, 0:50]
-        self.cate_his = inputs[:, 50:100]
-        self.brand_his = inputs[:, 100:150]
-        self.mask = inputs[:, 150:200]
-        self.match_mask = inputs[:, 200:250]
-
-        self.uid = inputs[:, 250]
-        self.cms_segid = inputs[:, 251]
-        self.cms_group_id = inputs[:, 252]
-        self.final_gender_code = inputs[:, 253]
-        self.age_level = inputs[:, 254]
-        self.pvalue_level = inputs[:, 255]
-        self.shopping_level = inputs[:, 256]
-        self.occupation = inputs[:, 257]
-        self.new_user_class_level = inputs[:, 258]
-
-        self.mid = inputs[:, 259]
-        self.cate_id = inputs[:, 260]
-        self.campaign_id = inputs[:, 261]
-        self.customer = inputs[:, 262]
-        self.brand = inputs[:, 263]
+        self.btag_his = inputs[:, 0:self.history_length]
+        self.cate_his = inputs[:, self.history_length:self.history_length * 2]
+        self.brand_his = inputs[:, self.history_length * 2:self.history_length
+                                * 3]
+        self.mask = inputs[:, self.history_length * 3:self.history_length * 4]
+        self.match_mask = inputs[:, self.history_length * 4:self.history_length
+                                 * 5]
+
+        self.uid = inputs[:, self.history_length * 5]
+        self.cms_segid = inputs[:, self.history_length * 5 + 1]
+        self.cms_group_id = inputs[:, self.history_length * 5 + 2]
+        self.final_gender_code = inputs[:, self.history_length * 5 + 3]
+        self.age_level = inputs[:, self.history_length * 5 + 4]
+        self.pvalue_level = inputs[:, self.history_length * 5 + 5]
+        self.shopping_level = inputs[:, self.history_length * 5 + 6]
+        self.occupation = inputs[:, self.history_length * 5 + 7]
+        self.new_user_class_level = inputs[:, self.history_length * 5 + 8]
+
+        self.mid = inputs[:, self.history_length * 5 + 9]
+        self.cate_id = inputs[:, self.history_length * 5 + 10]
+        self.campaign_id = inputs[:, self.history_length * 5 + 11]
+        self.customer = inputs[:, self.history_length * 5 + 12]
+        self.brand = inputs[:, self.history_length * 5 + 13]
         self.price = dense_tensor.astype('float32')

-        self.pid = inputs[:, 265]
+        self.pid = inputs[:, self.history_length * 5 + 15]

         if is_infer == 0:
-            self.labels = inputs[:, 266]
+            self.labels = inputs[:, self.history_length * 5 + 16]

         # embedding layer
         self.uid_batch_embedded = self.uid_embeddings_var(self.uid)
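The rewritten slicing encodes the sparse tensor's column layout: five history blocks of `history_length` columns each, followed by scalar user, ad, and context features. Offset `history_length * 5 + 14` is skipped on purpose: that position in the original layout was the price, which is read from the dense tensor instead, so `pid` sits at `+ 15` and the training label at `+ 16`. A small hypothetical helper (illustration only) that tabulates the offsets:

# Hypothetical helper (illustration only): column layout of the sparse
# input for history length L. Offset 5*L + 14 (price) is served by the
# dense tensor, hence the gap before pid at 5*L + 15.
def dmr_column_layout(L=50):
    layout = {
        'btag_his': (0, L),
        'cate_his': (L, 2 * L),
        'brand_his': (2 * L, 3 * L),
        'mask': (3 * L, 4 * L),
        'match_mask': (4 * L, 5 * L),
    }
    scalars = ['uid', 'cms_segid', 'cms_group_id', 'final_gender_code',
               'age_level', 'pvalue_level', 'shopping_level', 'occupation',
               'new_user_class_level', 'mid', 'cate_id', 'campaign_id',
               'customer', 'brand']
    for i, name in enumerate(scalars):
        layout[name] = 5 * L + i
    layout['pid'] = 5 * L + 15
    layout['labels'] = 5 * L + 16
    return layout

# With the default L=50 this reproduces the old literals:
# uid -> 250, brand -> 263, pid -> 265, labels -> 266.
assert dmr_column_layout(50)['uid'] == 250
assert dmr_column_layout(50)['pid'] == 265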
