@@ -2872,6 +2872,9 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # since it's not in the model forward's signature when using JIT
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()

             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
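For context on why the JIT tests flip this flag, here is a minimal sketch (not the Trainer's real implementation; the helper name and signature are hypothetical) of the behaviour that `model_accepts_loss_kwargs` gates: the extra `num_items_in_batch` keyword is only forwarded to the model when the flag is set, so a traced forward whose signature lacks that keyword needs the flag turned off.

# Hypothetical helper illustrating the gating behaviour, not the actual Trainer code
def call_forward(model, inputs, num_items_in_batch=None, model_accepts_loss_kwargs=True):
    kwargs = dict(inputs)
    # Only pass the extra loss kwarg when the model's forward is known to accept it;
    # a torch.jit-traced forward has a fixed signature and would fail on an unknown keyword.
    if model_accepts_loss_kwargs and num_items_in_batch is not None:
        kwargs["num_items_in_batch"] = num_items_in_batch
    return model(**kwargs)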
@@ -2885,6 +2888,7 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()

             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2903,6 +2907,7 @@ def test_evaluate_with_jit(self):
                 jit_mode_eval=True,
                 output_dir=tmp_dir,
             )
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()

             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2947,6 +2952,40 @@ def test_predict(self):
             self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
             self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

+    def test_train_and_predict_loss_parity(self):
+        """
+        Tests that the loss computed during a training_step is the same as the one computed during a
+        prediction_step for the same inputs.
+        """
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        # Create a dummy batch of variable-length inputs
+        inputs = {}
+        inputs["input_ids"] = []
+        for row_ind in range(4):
+            seq_len = torch.randint(32, 64, (1,)).item()
+            x = torch.randint(1, 100, (seq_len,))
+            inputs["input_ids"].append(x)
+        inputs["input_ids"] = torch.nn.utils.rnn.pad_sequence(inputs["input_ids"], batch_first=True, padding_value=0)
+        inputs["labels"] = inputs["input_ids"].clone()
+        inputs["labels"][inputs["input_ids"] == 0] = -100
+        num_items_in_batch = inputs["labels"].ne(-100).sum().item()
+
+        def custom_loss_func(outputs, labels, num_items_in_batch=None):
+            logits = outputs["logits"]
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+            if num_items_in_batch is not None:
+                return loss / num_items_in_batch  # normalize by the number of items in the batch
+            return loss
+
+        trainer = Trainer(model, train_dataset=None, compute_loss_func=custom_loss_func)
+
+        # Compute the loss once through training_step and once through prediction_step on the same inputs
+        train_loss = trainer.training_step(model, inputs, num_items_in_batch)
+        predict_loss = trainer.prediction_step(model, inputs, prediction_loss_only=True)[0]
+
+        torch.testing.assert_close(train_loss, predict_loss, atol=1e-6, rtol=0)
+
     def test_predict_with_batch_eval_metrics(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(
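As a standalone illustration of what the new test's num_items_in_batch measures, here is a small self-contained sketch (plain PyTorch, no Trainer involved, all values arbitrary) that pads a ragged batch, masks the padding positions in the labels with -100, and counts the tokens that actually contribute to the loss:

import torch

# Build a ragged batch of three token sequences (lengths 5, 7 and 3)
sequences = [torch.randint(1, 100, (n,)) for n in (5, 7, 3)]
input_ids = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)

# Labels mirror the inputs, with padding positions masked out by -100 (ignored by the loss)
labels = input_ids.clone()
labels[input_ids == 0] = -100

# num_items_in_batch counts only the positions that contribute to the loss: 5 + 7 + 3 = 15
num_items_in_batch = labels.ne(-100).sum().item()
print(num_items_in_batch)  # 15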
@@ -3014,18 +3053,23 @@ def test_predict_with_batch_eval_metrics(self):
     def test_predict_with_jit(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True, output_dir=tmp_dir)
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # since it's not in the model forward's signature when using JIT
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

             # With a number of elements not a round multiple of the batch size
             trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True, output_dir=tmp_dir)
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

             # With more than one output of the model
             trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True, output_dir=tmp_dir)
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertEqual(len(preds), 2)
@@ -3041,6 +3085,7 @@ def test_predict_with_jit(self):
                 jit_mode_eval=True,
                 output_dir=tmp_dir,
             )
+            trainer.model_accepts_loss_kwargs = False
             outputs = trainer.predict(trainer.eval_dataset)
             preds = outputs.predictions
             labels = outputs.label_ids