@@ -268,6 +268,88 @@ def _get_free_gpu_mem():
     return 64


+def _process_batch(batch_data, infer_method, max_working_batch_size=None):
+    """Process a batch of data through the model's inference method.
+
+    Args:
+        batch_data: Dictionary containing the batch data
+        infer_method: Model's inference method (either forward or generate)
+        max_working_batch_size: Maximum batch size known to work without OOM
+
+    Returns:
+        The maximum batch size that worked successfully
+    """
+    assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), (
+        "batch_data values must be tensors or None"
+    )
+    # Get the batch size of current data
+    batch_size = batch_data[list(batch_data.keys())[0]].shape[0]
+
+    # If we know a smaller batch size works, preemptively split
+    if max_working_batch_size is not None and batch_size > max_working_batch_size:
+        # Split the batch to avoid OOM
+        for i in range(0, batch_size, max_working_batch_size):
+            end_idx = min(i + max_working_batch_size, batch_size)
+            split_data = {}
+            for key in batch_data:
+                if batch_data[key] is None:
+                    split_data[key] = None
+                else:
+                    split_data[key] = batch_data[key][i:end_idx, ...]
+
+            max_working_batch_size = _process_batch(
+                split_data, infer_method, max_working_batch_size
+            )
+
+        return max_working_batch_size
+
+    # Try processing with current batch size
+    try:
+        infer_method(**batch_data)
+        return (
+            batch_size
+            if max_working_batch_size is None
+            else max(batch_size, max_working_batch_size)
+        )  # This batch size worked successfully
+    except torch.cuda.OutOfMemoryError:
+        assert batch_size > 1, (
+            "CUDA out of memory error occurred while processing a single sample. "
+            "This indicates the model is too large for the available GPU memory. "
+            "Consider reducing the model size, using a smaller max_sample_length, "
+            "or using a GPU with more memory."
+        )
+
+        # Split the batch in half, keeping None values as-is
+        mid = (batch_size + 1) // 2
+        warn(f"CUDA out of memory with batch size {batch_size}, trying with batch size {mid}")
+        split_data_1 = {key: None if batch_data[key] is None else batch_data[key][:mid, ...] for key in batch_data}
+        split_data_2 = {key: None if batch_data[key] is None else batch_data[key][mid:, ...] for key in batch_data}
+
+        # Recursively process each half and track the max working batch size
+        max_working_batch_size = _process_batch(split_data_1, infer_method)
+        max_working_batch_size = _process_batch(split_data_2, infer_method, max_working_batch_size)
+
+        # Return the largest batch size that worked across both halves
+        return max_working_batch_size
+
+
+def _forward_loop(model: torch.nn.Module, dataloader: DataLoader) -> None:
+    """Runs forward passes through the model using data from the dataloader.
+
+    Args:
+        model: The PyTorch model to run inference on
+        dataloader: DataLoader containing the batched input data
+    """
+    with torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
+
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(data, infer_method, max_working_batch_size)
+
+
 def create_forward_loop(
     model: Optional[torch.nn.Module] = None,
     dataset_name: str = "cnn_dailymail",
@@ -335,32 +417,7 @@ def create_forward_loop(
         include_labels=include_labels,
     )

-    def forward_loop(model):
-        with torch.no_grad():
-            low_mem_mode = False
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            for _, data in enumerate(tqdm(dataloader)):
-                batch_size = data[list(data.keys())[0]].shape[0]
-                if batch_size == 1:
-                    infer_method(**data)
-                elif not low_mem_mode:
-                    # Try running the forward once.
-                    # If output memory, we try running inference with split input tensors
-                    try:
-                        infer_method(**data)
-                    except torch.cuda.OutOfMemoryError:
-                        warn("torch.OutOfMemoryError detected, try reducing the batch size...")
-                        low_mem_mode = True
-
-                if low_mem_mode:
-                    split_data_1 = {key: data[key][: batch_size // 2, ...] for key in data}
-                    infer_method(**split_data_1)
-
-                    split_data_2 = {key: data[key][batch_size // 2 :, ...] for key in data}
-                    infer_method(**split_data_2)
-
-    return forward_loop
+    return lambda model: _forward_loop(model, dataloader)


 def model_type_is_enc_dec(model):
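
Below is a minimal, illustrative sketch of the recursive splitting behaviour introduced above. It assumes `_process_batch` from this diff is in scope; the `fake_infer` callable and its 4-sample OOM threshold are made up for the example and only simulate a model that runs out of memory on large batches.

```python
# Illustrative only: exercise _process_batch's recursive splitting with a fake
# inference callable that "runs out of memory" for batches larger than 4 samples.
import torch


def fake_infer(input_ids=None, attention_mask=None):
    # Simulate the model OOM-ing on large batches (threshold chosen arbitrarily).
    if input_ids.shape[0] > 4:
        raise torch.cuda.OutOfMemoryError("simulated OOM")


batch = {
    "input_ids": torch.randint(0, 1000, (16, 32)),
    "attention_mask": torch.ones(16, 32, dtype=torch.long),
}

# _process_batch halves 16 -> 8 -> 4 until every chunk fits, then reuses the
# known working size (4) to pre-split the remaining data instead of retrying 16.
largest = _process_batch(batch, fake_infer)
print(largest)  # 4
```

Compared with the removed closure, which only ever halved a batch once and stayed in low-memory mode for the rest of the loop, the recursive helper keeps halving until a chunk fits and then pre-splits subsequent batches to that size.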