Commit 1eda99c

Add image unit tests (#603)
* Add image unit tests for timm
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Guard imports
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Adjust GPU marker
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Unmock inference in timm tests
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Add test for running inference with a fake model
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Add aesthetic and nsfw tests
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Add more aesthetic and nsfw tests
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Add image text pair dataset tests
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Guard gpu import
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Fix function name
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Fix bug in image text pair dataset
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Address Sarah's review
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
* Replace broken test
  Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
---------
Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
1 parent 541f9b3 commit 1eda99c

7 files changed: +2772 -4 lines changed

nemo_curator/datasets/image_text_pair_dataset.py

Lines changed: 5 additions & 4 deletions
@@ -149,8 +149,8 @@ def filter_members(member):
     def _get_eligible_samples(self, output_path: str, samples_per_shard: int):
         parquet_glob_str = os.path.join(output_path, "temp_*.parquet")
         tar_glob_str = os.path.join(self.path, "*.tar")
-        parquet_files = open_files(parquet_glob_str)
-        tar_files = open_files(tar_glob_str)
+        parquet_files = sorted(open_files(parquet_glob_str), key=lambda f: f.path)
+        tar_files = sorted(open_files(tar_glob_str), key=lambda f: f.path)
 
         curr_df = None
         total_tar_samples = []
@@ -198,8 +198,9 @@ def _get_eligible_samples(self, output_path: str, samples_per_shard: int):
                     samples_per_shard * entries_per_sample :
                 ]
 
-        # Return the remaining df and samples
-        yield curr_df, total_tar_samples
+        # Return the remaining df and samples if it's not empty
+        if len(curr_df) > 0:
+            yield curr_df, total_tar_samples
 
     @staticmethod
     def _combine_id(shard_id, sample_id, max_shards=5, max_samples_per_shard=4) -> str:
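
The two behavioural changes in this diff can be illustrated with a small standalone sketch. It is not the NeMo Curator implementation: `pair_shards` and `batch_samples` are hypothetical helpers, plain strings stand in for the file objects returned by `open_files`, and integers stand in for dataframe rows. The commit message does not state the motivation, so the assumption here is that sorting by path keeps parquet and tar files aligned by index, and that the new length check stops the generator from emitting an empty final batch.

```python
from typing import Iterator, List, Tuple


def pair_shards(parquet_paths: List[str], tar_paths: List[str]) -> List[Tuple[str, str]]:
    """Pair parquet and tar paths by position after sorting both lists.

    Sorting first makes the pairing independent of the order in which
    the paths were listed.
    """
    return list(zip(sorted(parquet_paths), sorted(tar_paths)))


def batch_samples(samples: List[int], samples_per_shard: int) -> Iterator[List[int]]:
    """Yield full batches, then the remainder only if it is non-empty."""
    curr = list(samples)
    while len(curr) >= samples_per_shard:
        yield curr[:samples_per_shard]
        curr = curr[samples_per_shard:]
    # Mirrors the guard added in the diff: skip an empty trailing batch.
    if len(curr) > 0:
        yield curr


if __name__ == "__main__":
    # Unsorted listings still pair up consistently.
    print(pair_shards(["temp_00001.parquet", "temp_00000.parquet"],
                      ["00000.tar", "00001.tar"]))
    # Four samples split evenly into two batches, with no empty third batch.
    print(list(batch_samples([1, 2, 3, 4], 2)))
    # Five samples leave a one-element remainder, which is still yielded.
    print(list(batch_samples([1, 2, 3, 4, 5], 2)))
```

Without the final length check, the evenly divisible case would yield a trailing empty list, the list analogue of the empty `curr_df` that the old code yielded unconditionally.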
