@@ -274,6 +274,10 @@ def __init__(
     def _compute_masked_targets(
         self, item_ids: torch.Tensor, training: bool = False, testing: bool = False
     ) -> MaskingInfo:
+        if not training and not testing:
+            mask_labels = item_ids != self.padding_idx
+            return MaskingInfo(mask_labels, item_ids)
+
         masking_info = self.predict_all(item_ids)
         mask_labels, labels = masking_info.schema, masking_info.targets
 
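In plain terms, the new branch short-circuits prediction mode: when the model is neither training nor testing, every non-padded item becomes part of the mask schema and the raw `item_ids` are returned as targets. A minimal sketch of that behavior, with a made-up session and assuming `padding_idx = 0`:

```python
import torch

padding_idx = 0
item_ids = torch.tensor([[5, 3, 8, 0, 0]])  # one session, padded to length 5

# Inference-mode mask: every real interaction is kept, padding is excluded
mask_labels = item_ids != padding_idx
print(mask_labels)  # tensor([[ True,  True,  True, False, False]])
# MaskingInfo(mask_labels, item_ids) then carries the raw ids as targets.
```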
@@ -290,7 +294,8 @@ def _compute_masked_targets(
             label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions]
             # Updating labels and mask
             labels = label_seq_trg_eval
-            mask_labels = label_seq_trg_eval != self.padding_idx
+            # We only mask padded positions
+            mask_labels = item_ids != self.padding_idx
 
         return MaskingInfo(mask_labels, labels)
 
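The one-line change above matters because `label_seq_trg_eval` is zero everywhere except the last interaction of each session, so deriving `mask_labels` from it marked almost the whole sequence as padding. Deriving it from `item_ids` keeps all real positions in the mask schema while the loss targets still hold only the last item. A small illustration with invented values, again assuming `padding_idx = 0`:

```python
import torch

padding_idx = 0
item_ids = torch.tensor([[5, 3, 8, 0, 0]])
labels = item_ids.clone()

rows_ids = torch.arange(labels.size(0))
last_item_sessions = (item_ids != padding_idx).sum(dim=1) - 1  # index 2 here

# Targets keep only the last item of each session
label_seq_trg_eval = torch.zeros_like(labels)
label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions]

old_mask = label_seq_trg_eval != padding_idx  # [[False, False,  True, False, False]]
new_mask = item_ids != padding_idx            # [[ True,  True,  True, False, False]]
```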
@@ -302,6 +307,13 @@ def apply_mask_to_inputs(
         testing: bool = False,
     ) -> torch.Tensor:
         if not training and not testing:
+            # Replacing the inputs corresponding to padded items with a trainable embedding
+            # to mimic the training and evaluation masking strategy
+            inputs = torch.where(
+                mask_schema.unsqueeze(-1).bool(),
+                inputs,
+                self.masked_item_embedding.to(inputs.dtype),
+            )
             return inputs
         # shift sequence of interaction embeddings
         pos_emb_inp = inputs[:, :-1]
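For reference, the `torch.where` pattern used in both branches broadcasts the boolean schema over the embedding dimension, keeping embeddings where the schema is `True` and substituting the single trainable masked-item vector elsewhere. A self-contained sketch with invented sizes, where a plain `nn.Parameter` stands in for `self.masked_item_embedding`:

```python
import torch

hidden_size = 4
masked_item_embedding = torch.nn.Parameter(torch.rand(hidden_size))

inputs = torch.rand(1, 5, hidden_size)         # (batch, seq, hidden)
mask_schema = torch.tensor([[1, 1, 1, 0, 0]])  # non-padded positions

out = torch.where(
    mask_schema.unsqueeze(-1).bool(),          # (1, 5, 1), broadcasts over hidden
    inputs,                                    # kept where schema is True
    masked_item_embedding.to(inputs.dtype),    # (4,), fills padded slots
)
assert torch.equal(out[0, 0], inputs[0, 0])            # real item kept
assert torch.equal(out[0, 4], masked_item_embedding)   # padded slot replaced
```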
@@ -316,7 +328,7 @@ def apply_mask_to_inputs(
             ],
             axis=1,
         )
-        # Replacing the inputs corresponding to masked label with a trainable embedding
+        # Replacing the inputs corresponding to padded items with a trainable embedding
         pos_emb_inp = torch.where(
             mask_schema.unsqueeze(-1).bool(),
             pos_emb_inp,
@@ -601,14 +613,16 @@ def _compute_masked_targets_extended(
                 # from the interval `[cur_len, cur_len + context_length - span_length]`
                 start_index = (
                     cur_len
-                    + torch.randint(  # type: ignore
-                        context_length - span_length + 1, (1,)
+                    + torch.randint(
+                        context_length - span_length + 1, (1,)  # type: ignore
                     ).item()
                 )
                 if start_index < max_len:
                     # Mask the span of non-padded items
                     # `start_index:start_index + span_length`
-                    mask_labels[i, start_index : start_index + span_length] = 1
+                    mask_labels[
+                        i, start_index : start_index + span_length  # type: ignore
+                    ] = 1
                 # Set `cur_len = cur_len + context_length`
                 cur_len += context_length
             # if no item was masked:
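The hunk above is purely cosmetic (moving `# type: ignore` comments), but for context it sits inside the span-based masking loop: for each context window, a span start is drawn uniformly from `[cur_len, cur_len + context_length - span_length]` and `span_length` consecutive positions are marked. A stripped-down sketch of that loop with arbitrary values:

```python
import torch

max_len, span_length, context_length = 10, 2, 5
mask_labels = torch.zeros(1, max_len, dtype=torch.bool)

i, cur_len = 0, 0
while cur_len < max_len:
    # Draw the span start from [cur_len, cur_len + context_length - span_length]
    start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
    if start_index < max_len:
        # Mask the span `start_index:start_index + span_length`
        mask_labels[i, start_index : start_index + span_length] = 1
    cur_len += context_length

print(mask_labels.int())  # e.g. tensor([[0, 1, 1, 0, 0, 1, 1, 0, 0, 0]])
```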