
Commit f3c4d2a

fix of clm performance (#723)
1 parent d9f67f2 commit f3c4d2a

File tree: 3 files changed, +22 −9 lines


tests/unit/torch/features/test_sequential.py

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ def test_sequential_tabular_features_ignore_masking(schema, torch_yoochoose_like
         input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
     )
 
-    assert np.allclose(output_wo_masking, output_inference_masking, rtol=1e-04, atol=1e-08)
     assert not np.allclose(output_wo_masking, output_clm_masking, rtol=1e-04, atol=1e-08)
 
     input_module._masking = MaskedLanguageModeling(hidden_size=100)

tests/unit/torch/test_masking.py

Lines changed: 3 additions & 3 deletions
@@ -43,7 +43,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
     lm = tr.masking.masking_registry[task](
         hidden_dim, padding_idx=torch_masking_inputs["padding_idx"]
     )
-    lm.compute_masked_targets(torch_masking_inputs["labels"], training=False)
+    lm.compute_masked_targets(torch_masking_inputs["labels"], training=False, testing=True)
     # get non padded last items
     non_padded_mask = torch_masking_inputs["labels"] != torch_masking_inputs["padding_idx"]
     rows_ids = torch.arange(
@@ -57,7 +57,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
     trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
     out_last = lm.masked_targets[trgt_pad].flatten().numpy()
     # check that only one item is masked for each session
-    assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
+    assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
     # check only the last non-paded item is masked
     assert all(last_labels == out_last)
 
@@ -109,7 +109,7 @@ def test_clm_training_on_last_item(torch_masking_inputs):
     # last labels from output
     trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
     out_last = lm.masked_targets[trgt_pad].flatten().numpy()
-    assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
+    assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
     assert all(last_labels == out_last)
 
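
The eval test above now has to pass testing=True explicitly: with both flags left False, the masker takes the new inference path added in this commit, so the targets would be the raw item ids rather than only each session's last item. The assertions also count non-padded targets (trgt_pad) instead of mask_schema, since the schema now marks every non-padded position. A hedged usage sketch of the two modes (the tr alias, the "clm" registry key, and the sizes/padding values below follow the tests' conventions but are illustrative, not quoted from them):

    import torch
    import transformers4rec.torch as tr

    item_ids = torch.tensor([[5, 3, 8, 0, 0],
                             [2, 7, 4, 9, 0]])  # 0 is the padding index
    lm = tr.masking.masking_registry["clm"](100, padding_idx=0)

    # Evaluation: only the last non-padded item of each session is kept as a target.
    lm.compute_masked_targets(item_ids, training=False, testing=True)

    # Inference (the new path): targets are the item ids themselves and the
    # mask schema marks all non-padded positions.
    lm.compute_masked_targets(item_ids, training=False, testing=False)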

transformers4rec/torch/masking.py

Lines changed: 19 additions & 5 deletions
@@ -274,6 +274,10 @@ def __init__(
     def _compute_masked_targets(
         self, item_ids: torch.Tensor, training: bool = False, testing: bool = False
     ) -> MaskingInfo:
+        if not training and not testing:
+            mask_labels = item_ids != self.padding_idx
+            return MaskingInfo(mask_labels, item_ids)
+
         masking_info = self.predict_all(item_ids)
         mask_labels, labels = masking_info.schema, masking_info.targets
 
@@ -290,7 +294,8 @@ def _compute_masked_targets(
             label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions]
             # Updating labels and mask
            labels = label_seq_trg_eval
-            mask_labels = label_seq_trg_eval != self.padding_idx
+            # We only mask padded positions
+            mask_labels = item_ids != self.padding_idx
 
         return MaskingInfo(mask_labels, labels)
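
With this change, inference (neither training nor testing) short-circuits before predict_all: the schema keeps every non-padded position and the raw item ids are returned as targets, instead of applying the last-item logic used for training and evaluation. A minimal standalone sketch of that early return (plain PyTorch, not the library class; padding index 0 assumed):

    import torch

    def clm_inference_masking(item_ids: torch.Tensor, padding_idx: int = 0):
        # Mirrors the new early return: mask non-padded items, targets are the ids as-is.
        mask_labels = item_ids != padding_idx
        return mask_labels, item_ids

    item_ids = torch.tensor([[5, 3, 8, 0, 0],
                             [2, 7, 0, 0, 0]])
    schema, targets = clm_inference_masking(item_ids)
    # schema  -> [[True, True, True, False, False],
    #             [True, True, False, False, False]]
    # targets -> the original item_ids, unchanged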

@@ -302,6 +307,13 @@ def apply_mask_to_inputs(
         testing: bool = False,
     ) -> torch.Tensor:
         if not training and not testing:
+            # Replacing the inputs corresponding to padded items with a trainable embedding
+            # To mimic training and evaluation masking strategy
+            inputs = torch.where(
+                mask_schema.unsqueeze(-1).bool(),
+                inputs,
+                self.masked_item_embedding.to(inputs.dtype),
+            )
             return inputs
         # shift sequence of interaction embeddings
         pos_emb_inp = inputs[:, :-1]
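
At inference the sequence is no longer shifted; the only transformation is the torch.where above, which keeps the embeddings of non-padded items and swaps padded positions for the trainable masked-item embedding, mimicking what padded positions look like during training and evaluation. A small standalone sketch of that pattern (shapes and names are illustrative, not the class internals):

    import torch

    batch, seq_len, hidden = 2, 4, 3
    inputs = torch.randn(batch, seq_len, hidden)
    mask_schema = torch.tensor([[True, True, False, False],
                                [True, True, True, False]])
    masked_item_embedding = torch.nn.Parameter(torch.zeros(hidden))

    # Keep embeddings where the schema is True (non-padded items);
    # replace padded positions with the learned embedding (broadcast over hidden).
    out = torch.where(
        mask_schema.unsqueeze(-1),
        inputs,
        masked_item_embedding.to(inputs.dtype),
    )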
@@ -316,7 +328,7 @@ def apply_mask_to_inputs(
             ],
             axis=1,
         )
-        # Replacing the inputs corresponding to masked label with a trainable embedding
+        # Replacing the inputs corresponding to padded items with a trainable embedding
         pos_emb_inp = torch.where(
             mask_schema.unsqueeze(-1).bool(),
             pos_emb_inp,
@@ -601,14 +613,16 @@ def _compute_masked_targets_extended(
                 # from the interval `[cur_len, cur_len + context_length - span_length]`
                 start_index = (
                     cur_len
-                    + torch.randint(  # type: ignore
-                        context_length - span_length + 1, (1,)
+                    + torch.randint(
+                        context_length - span_length + 1, (1,)  # type: ignore
                     ).item()
                 )
                 if start_index < max_len:
                     # Mask the span of non-padded items
                     # `start_index:start_index + span_length`
-                    mask_labels[i, start_index : start_index + span_length] = 1
+                    mask_labels[
+                        i, start_index : start_index + span_length  # type: ignore
+                    ] = 1
                 # Set `cur_len = cur_len + context_length`
                 cur_len += context_length
             # if no item was masked:
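
The last hunk only moves the # type: ignore comments and reflows the indexing; the sampling itself is unchanged. As the comment notes, the span start is drawn uniformly from [cur_len, cur_len + context_length - span_length], because torch.randint(high, (1,)) samples from [0, high). A tiny standalone check of that arithmetic (the concrete values here are made up):

    import torch

    cur_len, span_length = 3, 2
    context_length = 6  # int(span_length / plm_probability) in the surrounding code

    start_index = (
        cur_len
        + torch.randint(context_length - span_length + 1, (1,)).item()
    )
    # randint(5, (1,)) yields 0..4, so start_index lands in [3, 7],
    # i.e. [cur_len, cur_len + context_length - span_length]
    assert cur_len <= start_index <= cur_len + context_length - span_length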
