Skip to content

Commit 4b522e0

Browse files
authored
Do not modify num calib data samples to batch boundary (#483)
Signed-off-by: Chenjie Luo <[email protected]>
1 parent 2e19b5a commit 4b522e0

File tree

3 files changed

+0
-9
lines changed

3 files changed

+0
-9
lines changed

modelopt/torch/utils/dataset_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
"""Utility functions for getting samples and forward loop function for different datasets."""
1717

1818
import copy
19-
import math
2019
from collections.abc import Callable
2120
from typing import TYPE_CHECKING, Any
2221
from warnings import warn
@@ -206,8 +205,6 @@ def get_dataset_dataloader(
206205
if isinstance(dataset_name, str):
207206
dataset_name = [dataset_name]
208207

209-
num_samples = [math.ceil(num_sample / batch_size) * batch_size for num_sample in num_samples]
210-
211208
assert len(dataset_name) == len(num_samples), (
212209
"dataset_name and num_samples must be the same length"
213210
)

modelopt/torch/utils/speech_dataset_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
"""Utility functions for getting samples and forward loop function for different speech datasets."""
1717

18-
import math
1918
from typing import Any
2019

2120
import torch
@@ -101,8 +100,6 @@ def get_speech_dataset_dataloader(
101100
"""
102101
assert processor is not None, "Please provide a valid processor."
103102

104-
num_samples = math.ceil(num_samples / batch_size) * batch_size
105-
106103
dataset = _get_speech_dataset(dataset_name, num_samples=num_samples)
107104
first_sample = next(iter(dataset))
108105
first_text = first_sample["text"]

modelopt/torch/utils/vlm_dataset_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
"""Utility functions for getting samples and forward loop function for different vlm datasets."""
1717

18-
import math
1918
from typing import Any
2019

2120
from torch.utils.data import DataLoader
@@ -93,8 +92,6 @@ def get_vlm_dataset_dataloader(
9392
"""
9493
assert processor is not None, "Please provide a valid processor."
9594

96-
num_samples = math.ceil(num_samples / batch_size) * batch_size
97-
9895
dataset = _get_vlm_dataset(dataset_name, num_samples=num_samples)
9996
# Apply the preprocessing function to the dataset
10097
processed_dataset = dataset.map(

0 commit comments

Comments
 (0)