
Commit 3edd804

pramodith and SunMarc authored
Trainer: Pass num_items_in_batch to compute_loss in prediction_step (#41183)
* Add num_items_in_batch computation to predict_step.
* address comments.
* Fix test cases.
* fixup

Co-authored-by: Marc Sun <[email protected]>
1 parent 59035fd commit 3edd804
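With this change, prediction_step computes the loss with the same per-token normalization used during training: the number of non-ignored label tokens in the batch is counted and forwarded to compute_loss. Below is a minimal, self-contained sketch of that idea; the names are illustrative, not the Trainer's actual implementation.

# Sketch only: count the non-ignored label tokens in a batch and use that
# count to normalize a summed cross-entropy loss, so evaluation is scaled
# the same way as training.
import torch

def num_items_in_batch(labels: torch.Tensor, ignore_index: int = -100) -> int:
    # Tokens equal to ignore_index do not contribute to the loss.
    return int(labels.ne(ignore_index).sum())

def normalized_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    n = num_items_in_batch(labels)
    summed = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction="sum"
    )
    return summed / max(n, 1)

if __name__ == "__main__":
    logits = torch.randn(2, 5, 10)   # (batch, seq_len, vocab)
    labels = torch.randint(0, 10, (2, 5))
    labels[:, -1] = -100             # padded / ignored positions
    print(normalized_loss(logits, labels))

Normalizing a summed loss by the token count keeps the reported evaluation loss on the same scale as the training loss, which is what the new parity test in this commit checks.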

File tree

2 files changed (+73, -14 lines)

src/transformers/trainer.py

Lines changed: 28 additions & 14 deletions
@@ -4844,7 +4844,10 @@ def prediction_step(
             else:
                 if has_labels or loss_without_labels:
                     with self.compute_loss_context_manager():
-                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                        num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device)
+                        loss, outputs = self.compute_loss(
+                            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
+                        )
                     loss = loss.detach().mean()
 
                     if isinstance(outputs, dict):
@@ -5533,21 +5536,16 @@ def _fsdp_qlora_plugin_updates(self):
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
 
-    def get_batch_samples(
-        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
-    ) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
+    def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) -> int | None:
         """
-        Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+        Counts the number of items in the batches to properly scale the loss.
+        Args:
+            batch_samples (`list`): List of batches
+            device (`torch.device`): The device on which the number of items in the batch should be.
+        Returns:
+            None if the number of items in the batch doesn't need to be computed else the number of items in the batch
         """
-        batch_samples = []
         num_items_in_batch = None
-
-        for _ in range(num_batches):
-            try:
-                batch_samples.append(next(epoch_iterator))
-            except StopIteration:
-                break
-
         count_num_items_in_batch = (
             len(batch_samples) > 0
             and "labels" in batch_samples[0]
@@ -5562,7 +5560,6 @@ def get_batch_samples(
                 # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
             )
         )
-
         if count_num_items_in_batch:
             # For now we don't support object detection
             try:
@@ -5588,6 +5585,23 @@ def get_batch_samples(
         if pc := getattr(self.accelerator, "parallelism_config", None):
             num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size
 
+        return num_items_in_batch
+
+    def get_batch_samples(
+        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+    ) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
+        """
+        Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+        """
+        batch_samples = []
+
+        for _ in range(num_batches):
+            try:
+                batch_samples.append(next(epoch_iterator))
+            except StopIteration:
+                break
+
+        num_items_in_batch = self._get_num_items_in_batch(batch_samples, device)
         return batch_samples, num_items_in_batch
 
     def set_initial_training_values(
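The refactor above moves the item-counting logic out of get_batch_samples into the new _get_num_items_in_batch helper, so the training loop and prediction_step can share it. The diff only shows part of the counting body; the following is a hedged sketch of what such a helper typically does. The real method also moves the count to the right device and rescales it for parallelism, and the function name below is illustrative only.

# Sketch only: count label tokens across a list of batches, guarded by the
# same kind of condition the diff shows (non-empty batch list with "labels").
import torch

def count_label_tokens(batch_samples: list, ignore_index: int = -100):
    if len(batch_samples) == 0 or "labels" not in batch_samples[0]:
        return None  # nothing to count: fall back to the default mean reduction
    return int(sum(batch["labels"].ne(ignore_index).sum() for batch in batch_samples))

batches = [
    {"labels": torch.tensor([[1, 2, -100], [3, -100, -100]])},
    {"labels": torch.tensor([[4, 5, 6]])},
]
print(count_label_tokens(batches))  # 6 non-ignored label tokens

Because prediction_step handles one batch at a time, it wraps its single inputs dict in a list ([inputs]) before calling the helper, as shown in the first hunk.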

tests/trainer/test_trainer.py

Lines changed: 45 additions & 0 deletions
@@ -2872,6 +2872,9 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # since it's not in the model forward's signature when using JIT
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
 
             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2885,6 +2888,7 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
 
             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2903,6 +2907,7 @@ def test_evaluate_with_jit(self):
                 jit_mode_eval=True,
                 output_dir=tmp_dir,
             )
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
 
             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2947,6 +2952,40 @@ def test_predict(self):
             self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
             self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
 
+    def test_train_and_predict_loss_parity(self):
+        """
+        Tests that the loss computed during a training_step is the same as the one computed during prediction_step
+        for the same inputs.
+        """
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        # Create a dummy batch of inputs
+        inputs = {}
+        inputs["input_ids"] = []
+        for row_ind in range(4):
+            seq_len = torch.randint(32, 64, (1,)).item()
+            x = torch.randint(1, 100, (seq_len,))
+            inputs["input_ids"].append(x)
+        inputs["input_ids"] = torch.nn.utils.rnn.pad_sequence(inputs["input_ids"], batch_first=True, padding_value=0)
+        inputs["labels"] = inputs["input_ids"].clone()
+        inputs["labels"][inputs["input_ids"] == 0] = -100
+        num_items_in_batch = inputs["labels"].ne(-100).sum().item()
+
+        def custom_loss_func(outputs, labels, num_items_in_batch=None):
+            logits = outputs["logits"]
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+            if num_items_in_batch is not None:
+                return loss / num_items_in_batch  # normalize by the number of items in the batch
+            return loss
+
+        trainer = Trainer(model, train_dataset=None, compute_loss_func=custom_loss_func)
+
+        # creating log history of trainer, results don't matter
+        train_loss = trainer.training_step(model, inputs, num_items_in_batch)
+        predict_loss = trainer.prediction_step(model, inputs, prediction_loss_only=True)[0]
+
+        torch.testing.assert_close(train_loss, predict_loss, atol=1e-6, rtol=0)
+
     def test_predict_with_batch_eval_metrics(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(
@@ -3014,18 +3053,23 @@ def test_predict_with_batch_eval_metrics(self):
     def test_predict_with_jit(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True, output_dir=tmp_dir)
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # since it's not in the model forward's signature when using JIT
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
 
             # With a number of elements not a round multiple of the batch size
             trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True, output_dir=tmp_dir)
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
 
             # With more than one output of the model
             trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True, output_dir=tmp_dir)
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertEqual(len(preds), 2)
@@ -3041,6 +3085,7 @@ def test_predict_with_jit(self):
                 jit_mode_eval=True,
                 output_dir=tmp_dir,
             )
+            trainer.model_accepts_loss_kwargs = False
             outputs = trainer.predict(trainer.eval_dataset)
             preds = outputs.predictions
             labels = outputs.label_ids
