@@ -268,6 +268,88 @@ def _get_free_gpu_mem():
     return 64
 
 
+def _process_batch(batch_data, infer_method, max_working_batch_size=None):
+    """Process a batch of data through the model's inference method.
+
+    Args:
+        batch_data: Dictionary containing the batch data
+        infer_method: Model's inference method (either forward or generate)
+        max_working_batch_size: Maximum batch size known to work without OOM
+
+    Returns:
+        The maximum batch size that worked successfully
+    """
+    assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), (
+        "batch_data values must be tensors or None"
+    )
+    # Get the batch size of current data
+    batch_size = batch_data[list(batch_data.keys())[0]].shape[0]
+
+    # If we know a smaller batch size works, preemptively split
+    if max_working_batch_size is not None and batch_size > max_working_batch_size:
+        # Split the batch to avoid OOM
+        for i in range(0, batch_size, max_working_batch_size):
+            end_idx = min(i + max_working_batch_size, batch_size)
+            split_data = {}
+            for key in batch_data:
+                if batch_data[key] is None:
+                    split_data[key] = None
+                else:
+                    split_data[key] = batch_data[key][i:end_idx, ...]
+
+            max_working_batch_size = _process_batch(
+                split_data, infer_method, max_working_batch_size
+            )
+
+        return max_working_batch_size
+
+    # Try processing with current batch size
+    try:
+        infer_method(**batch_data)
+        return (
+            batch_size
+            if max_working_batch_size is None
+            else max(batch_size, max_working_batch_size)
+        )  # This batch size worked successfully
+    except torch.cuda.OutOfMemoryError:
+        assert batch_size > 1, (
+            "CUDA out of memory error occurred while processing a single sample. "
+            "This indicates the model is too large for the available GPU memory. "
+            "Consider reducing the model size, using a smaller max_sample_length, "
+            "or using a GPU with more memory."
+        )
+
+        # Split the batch in half
+        mid = (batch_size + 1) // 2
+        warn(f"CUDA out of memory with batch size {batch_size}, trying with batch size {mid}")
+        split_data_1 = {k: None if v is None else v[:mid, ...] for k, v in batch_data.items()}
+        split_data_2 = {k: None if v is None else v[mid:, ...] for k, v in batch_data.items()}
+
+        # Recursively process each half and track max working batch size
+        max_working_batch_size = _process_batch(split_data_1, infer_method)
+        max_working_batch_size = _process_batch(split_data_2, infer_method, max_working_batch_size)
+
+        # Return the largest batch size that worked across both halves
+        return max_working_batch_size
+
+
+def _forward_loop(model: torch.nn.Module, dataloader: DataLoader) -> None:
+    """Runs forward passes through the model using data from the dataloader.
+
+    Args:
+        model: The PyTorch model to run inference on
+        dataloader: DataLoader containing the batched input data
+    """
+    with torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
+
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(data, infer_method, max_working_batch_size)
+
+
 def create_forward_loop(
     model: Optional[torch.nn.Module] = None,
     dataset_name: str = "cnn_dailymail",
@@ -335,32 +417,7 @@ def create_forward_loop(
         include_labels=include_labels,
     )
 
-    def forward_loop(model):
-        with torch.no_grad():
-            low_mem_mode = False
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            for _, data in enumerate(tqdm(dataloader)):
-                batch_size = data[list(data.keys())[0]].shape[0]
-                if batch_size == 1:
-                    infer_method(**data)
-                elif not low_mem_mode:
-                    # Try running the forward once.
-                    # If output memory, we try running inference with split input tensors
-                    try:
-                        infer_method(**data)
-                    except torch.cuda.OutOfMemoryError:
-                        warn("torch.OutOfMemoryError detected, try reducing the batch size...")
-                        low_mem_mode = True
-
-                if low_mem_mode:
-                    split_data_1 = {key: data[key][: batch_size // 2, ...] for key in data}
-                    infer_method(**split_data_1)
-
-                    split_data_2 = {key: data[key][batch_size // 2 :, ...] for key in data}
-                    infer_method(**split_data_2)
-
-    return forward_loop
+    return lambda model: _forward_loop(model, dataloader)
 
 
 def model_type_is_enc_dec(model):
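
For illustration, here is a minimal, self-contained sketch of how the new recursive splitting behaves. Everything in it is hypothetical example scaffolding rather than part of the commit: the import path for _process_batch is assumed (point it at the module this diff modifies), and fake_infer merely simulates an out-of-memory error for batches larger than two samples.

import torch

# Assumed import path -- adjust to the actual module containing _process_batch.
from dataset_utils import _process_batch


def fake_infer(input_ids=None, attention_mask=None):
    """Stand-in for model.forward/model.generate: pretend only 2 samples fit on the GPU."""
    if input_ids.shape[0] > 2:
        raise torch.cuda.OutOfMemoryError("simulated OOM")


batch = {
    "input_ids": torch.ones(8, 16, dtype=torch.long),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}

# The first attempt with batch size 8 raises the simulated OOM, so the helper
# recursively halves 8 -> 4 -> 2 and returns 2, the largest size that succeeded.
max_working = _process_batch(batch, fake_infer)
print(max_working)  # 2

# Feeding the result back in pre-splits later batches into chunks of 2,
# so the simulated OOM is never raised again.
_process_batch(batch, fake_infer, max_working_batch_size=max_working)

This is also the main difference from the removed low_mem_mode path, which only ever split a batch in half once: the new helper keeps halving until the batch fits and remembers the working size across batches via the value _forward_loop threads through its loop.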