
Commit e33610c

Added check of the output names while removing cast nodes (#127)
Co-Authored-By: Davit <[email protected]>
1 parent e55f19f commit e33610c

File tree

1 file changed: +83 -26 lines

modelopt/torch/utils/dataset_utils.py

Lines changed: 83 additions & 26 deletions
@@ -268,6 +268,88 @@ def _get_free_gpu_mem():
     return 64


+def _process_batch(batch_data, infer_method, max_working_batch_size=None):
+    """Process a batch of data through the model's inference method.
+
+    Args:
+        batch_data: Dictionary containing the batch data
+        infer_method: Model's inference method (either forward or generate)
+        max_working_batch_size: Maximum batch size known to work without OOM
+
+    Returns:
+        The maximum batch size that worked successfully
+    """
+    assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), (
+        "batch_data values must be tensors"
+    )
+    # Get the batch size of current data
+    batch_size = batch_data[list(batch_data.keys())[0]].shape[0]
+
+    # If we know a smaller batch size works, preemptively split
+    if max_working_batch_size is not None and batch_size > max_working_batch_size:
+        # Split the batch to avoid OOM
+        for i in range(0, batch_size, max_working_batch_size):
+            end_idx = min(i + max_working_batch_size, batch_size)
+            split_data = {}
+            for key in batch_data:
+                if batch_data[key] is None:
+                    split_data[key] = None
+                else:
+                    split_data[key] = batch_data[key][i:end_idx, ...]
+
+            max_working_batch_size = _process_batch(
+                split_data, infer_method, max_working_batch_size
+            )
+
+        return max_working_batch_size
+
+    # Try processing with current batch size
+    try:
+        infer_method(**batch_data)
+        return (
+            batch_size
+            if max_working_batch_size is None
+            else max(batch_size, max_working_batch_size)
+        )  # This batch size worked successfully
+    except torch.cuda.OutOfMemoryError:
+        assert batch_size > 1, (
+            "CUDA out of memory error occurred while processing a single sample. "
+            "This indicates the model is too large for the available GPU memory. "
+            "Consider reducing the model size, using a smaller max_sample_length, "
+            "or using a GPU with more memory."
+        )
+
+        # Split the batch in half
+        mid = (batch_size + 1) // 2
+        warn(f"CUDA out of memory with batch size {batch_size}, trying with batch size {mid}")
+        split_data_1 = {key: batch_data[key][:mid, ...] for key in batch_data}
+        split_data_2 = {key: batch_data[key][mid:, ...] for key in batch_data}
+
+        # Recursively process each half and track max working batch size
+        max_working_batch_size = _process_batch(split_data_1, infer_method)
+        max_working_batch_size = _process_batch(split_data_2, infer_method, max_working_batch_size)
+
+        # Return the minimum of the two (to be conservative)
+        return max_working_batch_size
+
+
+def _forward_loop(model: torch.nn.Module, dataloader: DataLoader) -> None:
+    """Runs forward passes through the model using data from the dataloader.
+
+    Args:
+        model: The PyTorch model to run inference on
+        dataloader: DataLoader containing the batched input data
+    """
+    with torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
+
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(data, infer_method, max_working_batch_size)
+
+
 def create_forward_loop(
     model: Optional[torch.nn.Module] = None,
     dataset_name: str = "cnn_dailymail",
@@ -335,32 +417,7 @@ def create_forward_loop
        include_labels=include_labels,
    )

-    def forward_loop(model):
-        with torch.no_grad():
-            low_mem_mode = False
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            for _, data in enumerate(tqdm(dataloader)):
-                batch_size = data[list(data.keys())[0]].shape[0]
-                if batch_size == 1:
-                    infer_method(**data)
-                elif not low_mem_mode:
-                    # Try running the forward once.
-                    # If output memory, we try running inference with split input tensors
-                    try:
-                        infer_method(**data)
-                    except torch.cuda.OutOfMemoryError:
-                        warn("torch.OutOfMemoryError detected, try reducing the batch size...")
-                        low_mem_mode = True
-
-                if low_mem_mode:
-                    split_data_1 = {key: data[key][: batch_size // 2, ...] for key in data}
-                    infer_method(**split_data_1)
-
-                    split_data_2 = {key: data[key][batch_size // 2 :, ...] for key in data}
-                    infer_method(**split_data_2)
-
-    return forward_loop
+    return lambda model: _forward_loop(model, dataloader)


 def model_type_is_enc_dec(model):
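For context on how the refactored loop is consumed, below is a hedged usage sketch (not from this commit) of calibrating a Hugging Face model with the forward loop returned by create_forward_loop. The model and dataset_name arguments appear in the signature above; the remaining keyword arguments, the model name, and the quantization config are assumptions and may differ from the actual ModelOpt API.

# Hypothetical usage sketch; keyword arguments beyond model/dataset_name are assumed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import create_forward_loop

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# create_forward_loop now returns `lambda model: _forward_loop(model, dataloader)`,
# so every calibration batch goes through _process_batch: batches that hit
# torch.cuda.OutOfMemoryError are recursively halved, and later batches are
# pre-split to the largest size that has fit so far.
forward_loop = create_forward_loop(
    dataset_name="cnn_dailymail",
    tokenizer=tokenizer,  # assumed keyword argument
    batch_size=16,        # assumed keyword argument
    num_samples=512,      # assumed keyword argument
)

# mtq.quantize calls forward_loop(model) internally to collect calibration stats.
model = mtq.quantize(model, mtq.INT8_SMOOTHQUANT_CFG, forward_loop)

Moving the loop body out of the nested closure into the module-level _process_batch and _forward_loop helpers also makes the OOM-splitting logic reusable and easier to test in isolation.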
