 class Mind_SampledSoftmaxLoss_Layer(nn.Layer):
     """SampledSoftmaxLoss with LogUniformSampler
     """
+
     def __init__(self,
                  num_classes,
                  n_sample,
@@ -83,13 +84,13 @@ def forward(self, inputs, labels, weights, bias):
         sample_b = all_b[-n_sample:]

         # [B, D] * [B, 1,D]
-        true_logist = paddle.sum(paddle.multiply(
-            true_w, inputs.unsqueeze(1)), axis=-1) + true_b
+        true_logist = paddle.sum(paddle.multiply(true_w, inputs.unsqueeze(1)),
+                                 axis=-1) + true_b
         # print(true_logist)
-
+
         sample_logist = paddle.matmul(
-            inputs, sample_w, transpose_y=True) + sample_b
-
+            inputs, sample_w, transpose_y=True) + sample_b
+
         if self.remove_accidental_hits:
             hit = (paddle.equal(labels[:, :], neg_samples))
             padding = paddle.ones_like(sample_logist) * -1e30
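Note (not part of the diff): the `remove_accidental_hits` branch masks out any sampled negative that happens to equal the true label, so it is not penalized as a negative. A minimal standalone sketch of that idea follows; the final `paddle.where` merge is an assumption, since the rest of the branch lies outside this hunk, and the tensor names simply mirror the code above.

import paddle

# toy shapes: B=2 examples, n_sample=3 shared negatives
labels = paddle.to_tensor([[3], [7]], dtype='int64')       # true class id per example, [B, 1]
neg_samples = paddle.to_tensor([3, 5, 9], dtype='int64')   # sampled negative class ids, [n_sample]
sample_logist = paddle.rand([2, 3])                        # logits of the sampled classes, [B, n_sample]

hit = paddle.equal(labels[:, :], neg_samples)              # True where a negative collides with the label
padding = paddle.ones_like(sample_logist) * -1e30          # acts as -inf under softmax
sample_logist = paddle.where(hit, padding, sample_logist)  # assumed merge step; drops accidental hits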
@@ -115,6 +116,7 @@ def forward(self, inputs, labels, weights, bias):
 class Mind_Capsual_Layer(nn.Layer):
     """Mind_Capsual_Layer
     """
+
     def __init__(self,
                  input_units,
                  output_units,
@@ -189,11 +191,13 @@ def forward(self, item_his_emb, seq_len):

         low_capsule_new_tile = paddle.tile(low_capsule_new, [1, 1, self.k_max])
         low_capsule_new_tile = paddle.reshape(
-            low_capsule_new_tile, [-1, self.maxlen, self.k_max, self.output_units])
-        low_capsule_new_tile = paddle.transpose(
-            low_capsule_new_tile, [0, 2, 1, 3])
+            low_capsule_new_tile,
+            [-1, self.maxlen, self.k_max, self.output_units])
+        low_capsule_new_tile = paddle.transpose(low_capsule_new_tile,
+                                                [0, 2, 1, 3])
         low_capsule_new_tile = paddle.reshape(
-            low_capsule_new_tile, [-1, self.k_max, self.maxlen, self.output_units])
+            low_capsule_new_tile,
+            [-1, self.k_max, self.maxlen, self.output_units])
         low_capsule_new_nograd = paddle.assign(low_capsule_new_tile)
         low_capsule_new_nograd.stop_gradient = True

@@ -209,8 +213,9 @@ def forward(self, item_his_emb, seq_len):
             high_capsule_tmp = paddle.matmul(W, low_capsule_new_nograd)
             # print(low_capsule_new_nograd.shape)
             high_capsule = self.squash(high_capsule_tmp)
-            B_delta = paddle.matmul(low_capsule_new_nograd,
-                                    paddle.transpose(high_capsule, [0, 1, 3, 2]))
+            B_delta = paddle.matmul(
+                low_capsule_new_nograd,
+                paddle.transpose(high_capsule, [0, 1, 3, 2]))
             B_delta = paddle.reshape(
                 B_delta, shape=[-1, self.k_max, self.maxlen])
             B += B_delta
@@ -220,8 +225,8 @@ def forward(self, item_his_emb, seq_len):
         W = paddle.unsqueeze(W, axis=2)
         interest_capsule = paddle.matmul(W, low_capsule_new_tile)
         interest_capsule = self.squash(interest_capsule)
-        high_capsule = paddle.reshape(
-            interest_capsule, [-1, self.k_max, self.output_units])
+        high_capsule = paddle.reshape(interest_capsule,
+                                      [-1, self.k_max, self.output_units])

         high_capsule = F.relu(self.relu_layer(high_capsule))
         return high_capsule, W, seq_len
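Note (not part of the diff): the routing hunks above call `self.squash`, whose body is not touched by this change. For context, the standard capsule squash nonlinearity used in the MIND paper looks roughly like the sketch below; treat it as an assumption about the layer's behaviour, not a copy of the repository code.

import paddle

def squash(caps, eps=1e-9):
    # caps: [..., output_units]; rescale each capsule vector so its norm falls in (0, 1)
    sq_norm = paddle.sum(paddle.square(caps), axis=-1, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm) / paddle.sqrt(sq_norm + eps)
    return caps * scale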
@@ -277,12 +282,16 @@ def __init__(self,
     def label_aware_attention(self, keys, query):
         """label_aware_attention
         """
-        weight = paddle.matmul(keys, paddle.reshape(query, [-1, paddle.shape(query)[-1], 1])) #[B, K, dim] * [B, dim, 1] == [B, k, 1]
+        weight = paddle.matmul(keys,
+                               paddle.reshape(query, [
+                                   -1, paddle.shape(query)[-1], 1
+                               ]))  #[B, K, dim] * [B, dim, 1] == [B, k, 1]
         weight = paddle.squeeze(weight, axis=-1)
         weight = paddle.pow(weight, self.pow_p)  # [x,k_max]
-        weight = F.softmax(weight) #[x, k_max]
-        weight = paddle.unsqueeze(weight, 1) #[B, 1, k_max]
-        output = paddle.matmul(weight, keys) #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim]
+        weight = F.softmax(weight)  #[x, k_max]
+        weight = paddle.unsqueeze(weight, 1)  #[B, 1, k_max]
+        output = paddle.matmul(
+            weight, keys)  #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim]
         return output.squeeze(1), weight

     def forward(self, hist_item, seqlen, labels=None):
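Note (not part of the diff): after the reformat, `label_aware_attention` reads as: score each interest capsule against the label embedding, sharpen the scores with `pow_p`, softmax over the `k_max` interests, and return the weighted sum of capsules. A self-contained sketch with toy shapes is below; the free-function signature and the values are illustrative only, not the module API.

import paddle
import paddle.nn.functional as F

def label_aware_attention(keys, query, pow_p=2.0):
    # keys: [B, k_max, dim] interest capsules; query: [B, dim] label (target item) embedding
    weight = paddle.matmul(
        keys, paddle.reshape(query, [-1, paddle.shape(query)[-1], 1]))  # [B, k_max, 1]
    weight = paddle.squeeze(weight, axis=-1)   # [B, k_max]
    weight = paddle.pow(weight, pow_p)         # sharpen the attention distribution
    weight = F.softmax(weight)                 # normalize over the k_max interests
    weight = paddle.unsqueeze(weight, 1)       # [B, 1, k_max]
    output = paddle.matmul(weight, keys)       # [B, 1, k_max] * [B, k_max, dim] => [B, 1, dim]
    return output.squeeze(1), weight

keys = paddle.rand([4, 3, 8])    # B=4, k_max=3, dim=8
query = paddle.rand([4, 8])
out, att = label_aware_attention(keys, query)
print(out.shape, att.shape)      # [4, 8] and [4, 1, 3]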