@@ -140,6 +140,7 @@ def dnn_model_define(user_input,
140140 fea_groups = "20,20,10,10,2,2,2,1,1,1" ,
141141 active_op = 'prelu' ,
142142 use_batch_norm = True ,
143+ with_att = False ,
143144 is_infer = False ,
144145 topk = 10 ):
145146 fea_groups = [int (s ) for s in fea_groups .split (',' )]
@@ -148,30 +149,37 @@ def dnn_model_define(user_input,
148149
149150 layer_data = []
150151 # start att
151- att_user_input = paddle .concat (
152- user_input , axis = 1 ) # [bs, total_group_length, emb_size]
153- att_node_input = fluid .layers .expand (
154- unit_id_emb , expand_times = [1 , total_group_length , 1 ])
155- att_din = paddle .concat (
156- [att_user_input , att_user_input * att_node_input , att_node_input ],
157- axis = 2 )
158-
159- att_active_op = 'prelu'
160- att_layer_arr = []
161- att_layer1 = FullyConnected3D (
162- 3 * node_emb_size , 36 , active_op = att_active_op , version = 1 )
163- att_layer_arr .append (att_layer1 )
164- att_layer2 = FullyConnected3D (36 , 1 , active_op = att_active_op , version = 2 )
165- att_layer_arr .append (att_layer2 )
166-
167- layer_data .append (att_din )
168- for layer in att_layer_arr :
169- layer_data .append (layer .call (layer_data [- 1 ]))
170- att_dout = layer_data [- 1 ]
171-
172- att_dout = fluid .layers .expand (
173- att_dout , expand_times = [1 , 1 , node_emb_size ])
174- user_input = att_user_input * att_dout
152+ if with_att :
153+ print ("TDM Attention DNN" )
154+ att_user_input = paddle .concat (
155+ user_input , axis = 1 ) # [bs, total_group_length, emb_size]
156+ att_node_input = fluid .layers .expand (
157+ unit_id_emb , expand_times = [1 , total_group_length , 1 ])
158+ att_din = paddle .concat (
159+ [att_user_input , att_user_input * att_node_input , att_node_input ],
160+ axis = 2 )
161+
162+ att_active_op = 'prelu'
163+ att_layer_arr = []
164+ att_layer1 = FullyConnected3D (
165+ 3 * node_emb_size , 36 , active_op = att_active_op , version = 1 )
166+ att_layer_arr .append (att_layer1 )
167+ att_layer2 = FullyConnected3D (
168+ 36 , 1 , active_op = att_active_op , version = 2 )
169+ att_layer_arr .append (att_layer2 )
170+
171+ layer_data .append (att_din )
172+ for layer in att_layer_arr :
173+ layer_data .append (layer .call (layer_data [- 1 ]))
174+ att_dout = layer_data [- 1 ]
175+
176+ att_dout = fluid .layers .expand (
177+ att_dout , expand_times = [1 , 1 , node_emb_size ])
178+ user_input = att_user_input * att_dout
179+ else :
180+ print ("TDM DNN" )
181+ user_input = paddle .concat (
182+ user_input , axis = 1 ) # [bs, total_group_length, emb_size]
175183 # end att
176184
177185 idx = 0
@@ -207,13 +215,13 @@ def dnn_model_define(user_input,
207215 layer_arr .append (layer2 )
208216 layer3 = paddle_dnn_layer (
209217 64 ,
210- 32 ,
218+ 24 ,
211219 active_op = active_op ,
212220 use_batch_norm = use_batch_norm ,
213221 version = "%d_%s" % (3 , net_version ))
214222 layer_arr .append (layer3 )
215223 layer4 = paddle_dnn_layer (
216- 32 ,
224+ 24 ,
217225 2 ,
218226 active_op = '' ,
219227 use_batch_norm = False ,
0 commit comments