Skip to content

Commit c8c1e0b

Browse files
authored
Add ArgumentHelper tests (#595)
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
1 parent d6fcbdb commit c8c1e0b

File tree

1 file changed

+376
-0
lines changed

1 file changed

+376
-0
lines changed

tests/test_argument_helper.py

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
from nemo_curator.utils.script_utils import ArgumentHelper
18+
19+
20+
class TestArgumentHelper:
    """Smoke tests for the argument-registration methods of ``ArgumentHelper``.

    Most tests wrap a fresh ``argparse.ArgumentParser`` in an
    ``ArgumentHelper``, call a single helper method, and check that the
    expected option strings (and any custom help text) show up in the
    rendered help output.
    """

    # Option strings asserted after ``add_distributed_args()`` and in every
    # composite-helper test below.
    _DISTRIBUTED_FLAGS = (
        "--device",
        "--files-per-partition",
        "--n-workers",
        "--num-files",
        "--nvlink-only",
        "--protocol",
        "--rmm-pool-size",
        "--scheduler-address",
        "--scheduler-file",
        "--threads-per-worker",
    )

    # Distributed flags plus the extra options asserted in the two
    # distributed-classifier tests.
    _CLASSIFIER_CLUSTER_FLAGS = _DISTRIBUTED_FLAGS + (
        "--enable-spilling",
        "--set-torch-to-use-rmm",
        "--max-mem-gb-classifier",
    )

    @staticmethod
    def _make_helper():
        """Return an ``ArgumentHelper`` wrapping a brand-new argparse parser."""
        return ArgumentHelper(argparse.ArgumentParser())

    @staticmethod
    def _assert_in_help(parser, *fragments):
        """Assert that every string in ``fragments`` occurs in ``parser``'s help.

        The help text is rendered once and reused for all fragments.
        """
        help_text = parser.format_help()
        for fragment in fragments:
            assert fragment in help_text, f"{fragment!r} missing from help output"

    def test_argument_helper(self):
        """A freshly wrapped parser mentions ``--version`` in its help."""
        helper = self._make_helper()
        self._assert_in_help(helper.parser, "--version")

    def test_attach_bool_arg(self):
        """``attach_bool_arg`` yields ``--test``/``--no-test`` and the help text."""
        parser = argparse.ArgumentParser()
        helper = ArgumentHelper(parser)
        helper.attach_bool_arg(
            helper.parser, "test", default=False, help="test help"
        )
        # Checked on the original parser object, not on ``helper.parser``.
        self._assert_in_help(parser, "--test", "test help", "--no-test")

    def test_add_arg_batch_size(self):
        """``add_arg_batch_size`` exposes ``--batch-size``."""
        helper = self._make_helper()
        helper.add_arg_batch_size()
        self._assert_in_help(helper.parser, "--batch-size")

    def test_add_arg_device(self):
        """``add_arg_device`` exposes ``--device``."""
        helper = self._make_helper()
        helper.add_arg_device()
        self._assert_in_help(helper.parser, "--device")

    def test_add_arg_enable_spilling(self):
        """``add_arg_enable_spilling`` exposes ``--enable-spilling``."""
        helper = self._make_helper()
        helper.add_arg_enable_spilling()
        self._assert_in_help(helper.parser, "--enable-spilling")

    def test_add_arg_language(self):
        """``add_arg_language`` exposes ``--language`` with the given help."""
        helper = self._make_helper()
        help_text = "The language of the dataset to be read in."
        helper.add_arg_language(help=help_text)
        self._assert_in_help(helper.parser, "--language", help_text)

    def test_add_arg_log_dir(self):
        """``add_arg_log_dir`` exposes ``--log-dir``."""
        helper = self._make_helper()
        log_dir_default = "./log"
        helper.add_arg_log_dir(default=log_dir_default)
        self._assert_in_help(helper.parser, "--log-dir")

    def test_add_arg_input_data_dir(self):
        """``add_arg_input_data_dir`` exposes ``--input-data-dir``."""
        helper = self._make_helper()
        helper.add_arg_input_data_dir()
        self._assert_in_help(helper.parser, "--input-data-dir")

    def test_add_arg_input_file_type(self):
        """``add_arg_input_file_type`` lists the flag and each choice."""
        helper = self._make_helper()
        file_types = ["jsonl", "pickle", "parquet"]
        helper.add_arg_input_file_type(choices=file_types)
        self._assert_in_help(helper.parser, "--input-file-type", *file_types)

    def test_add_arg_input_file_extension(self):
        """``add_arg_input_file_extension`` exposes ``--input-file-extension``."""
        helper = self._make_helper()
        helper.add_arg_input_file_extension()
        self._assert_in_help(helper.parser, "--input-file-extension")

    def test_add_arg_input_local_data_dir(self):
        """``add_arg_input_local_data_dir`` exposes ``--input-local-data-dir``."""
        helper = self._make_helper()
        helper.add_arg_input_local_data_dir()
        self._assert_in_help(helper.parser, "--input-local-data-dir")

    def test_add_arg_input_meta(self):
        """``add_arg_input_meta`` exposes ``--input-meta``."""
        helper = self._make_helper()
        helper.add_arg_input_meta()
        self._assert_in_help(helper.parser, "--input-meta")

    def test_add_arg_input_text_field(self):
        """``add_arg_input_text_field`` exposes ``--input-text-field``."""
        helper = self._make_helper()
        helper.add_arg_input_text_field()
        self._assert_in_help(helper.parser, "--input-text-field")

    def test_add_arg_id_column(self):
        """``add_arg_id_column`` exposes ``--id-column``."""
        helper = self._make_helper()
        helper.add_arg_id_column()
        self._assert_in_help(helper.parser, "--id-column")

    def test_add_arg_id_column_type(self):
        """``add_arg_id_column_type`` exposes ``--id-column-type``."""
        helper = self._make_helper()
        helper.add_arg_id_column_type()
        self._assert_in_help(helper.parser, "--id-column-type")

    def test_add_arg_minhash_length(self):
        """``add_arg_minhash_length`` exposes ``--minhash-length``."""
        helper = self._make_helper()
        helper.add_arg_minhash_length()
        self._assert_in_help(helper.parser, "--minhash-length")

    def test_add_arg_nvlink_only(self):
        """``add_arg_nvlink_only`` exposes ``--nvlink-only``."""
        helper = self._make_helper()
        helper.add_arg_nvlink_only()
        self._assert_in_help(helper.parser, "--nvlink-only")

    def test_add_arg_output_data_dir(self):
        """``add_arg_output_data_dir`` exposes ``--output-data-dir`` with help."""
        helper = self._make_helper()
        help_text = "Output data directory."
        helper.add_arg_output_data_dir(help=help_text)
        self._assert_in_help(helper.parser, "--output-data-dir", help_text)

    def test_add_arg_output_dir(self):
        """``add_arg_output_dir`` exposes ``--output-dir``."""
        helper = self._make_helper()
        helper.add_arg_output_dir()
        self._assert_in_help(helper.parser, "--output-dir")

    def test_add_arg_output_file_type(self):
        """``add_arg_output_file_type`` lists the flag and each choice."""
        helper = self._make_helper()
        file_types = ["jsonl", "pickle", "parquet"]
        helper.add_arg_output_file_type(choices=file_types)
        self._assert_in_help(helper.parser, "--output-file-type", *file_types)

    def test_add_arg_output_train_file(self):
        """``add_arg_output_train_file`` exposes ``--output-train-file`` with help."""
        helper = self._make_helper()
        help_text = "The output train file."
        helper.add_arg_output_train_file(help=help_text)
        self._assert_in_help(helper.parser, "--output-train-file", help_text)

    def test_add_arg_protocol(self):
        """``add_arg_protocol`` exposes ``--protocol``."""
        helper = self._make_helper()
        helper.add_arg_protocol()
        self._assert_in_help(helper.parser, "--protocol")

    def test_add_arg_rmm_pool_size(self):
        """``add_arg_rmm_pool_size`` exposes ``--rmm-pool-size``."""
        helper = self._make_helper()
        helper.add_arg_rmm_pool_size()
        self._assert_in_help(helper.parser, "--rmm-pool-size")

    def test_add_arg_scheduler_address(self):
        """``add_arg_scheduler_address`` exposes ``--scheduler-address``."""
        helper = self._make_helper()
        helper.add_arg_scheduler_address()
        self._assert_in_help(helper.parser, "--scheduler-address")

    def test_add_arg_scheduler_file(self):
        """``add_arg_scheduler_file`` exposes ``--scheduler-file``."""
        helper = self._make_helper()
        helper.add_arg_scheduler_file()
        self._assert_in_help(helper.parser, "--scheduler-file")

    def test_add_arg_seed(self):
        """``add_arg_seed`` exposes ``--seed``."""
        helper = self._make_helper()
        helper.add_arg_seed()
        self._assert_in_help(helper.parser, "--seed")

    def test_add_arg_set_torch_to_use_rmm(self):
        """``add_arg_set_torch_to_use_rmm`` exposes ``--set-torch-to-use-rmm``."""
        helper = self._make_helper()
        helper.add_arg_set_torch_to_use_rmm()
        self._assert_in_help(helper.parser, "--set-torch-to-use-rmm")

    def test_add_arg_shuffle(self):
        """``add_arg_shuffle`` exposes ``--shuffle`` with the given help."""
        helper = self._make_helper()
        help_text = "Shuffle argument help"
        helper.add_arg_shuffle(help=help_text)
        self._assert_in_help(helper.parser, "--shuffle", help_text)

    def test_add_arg_text_ddf_blocksize(self):
        """``add_arg_text_ddf_blocksize`` exposes ``--text-ddf-blocksize``."""
        helper = self._make_helper()
        helper.add_arg_text_ddf_blocksize()
        self._assert_in_help(helper.parser, "--text-ddf-blocksize")

    def test_add_arg_model_path(self):
        """``add_arg_model_path`` exposes ``--pretrained-model-name-or-path``."""
        helper = self._make_helper()
        helper.add_arg_model_path()
        self._assert_in_help(helper.parser, "--pretrained-model-name-or-path")

    def test_add_arg_max_mem_gb_classifier(self):
        """``add_arg_max_mem_gb_classifier`` exposes ``--max-mem-gb-classifier``."""
        helper = self._make_helper()
        helper.add_arg_max_mem_gb_classifier()
        self._assert_in_help(helper.parser, "--max-mem-gb-classifier")

    def test_add_arg_autocast(self):
        """``add_arg_autocast`` exposes ``--autocast``."""
        helper = self._make_helper()
        helper.add_arg_autocast()
        self._assert_in_help(helper.parser, "--autocast")

    def test_add_arg_max_chars(self):
        """``add_arg_max_chars`` exposes ``--max-chars``."""
        helper = self._make_helper()
        helper.add_arg_max_chars()
        self._assert_in_help(helper.parser, "--max-chars")

    def test_distributed_args(self):
        """``add_distributed_args`` registers every distributed flag."""
        helper = self._make_helper()
        helper.add_distributed_args()
        self._assert_in_help(helper.parser, *self._DISTRIBUTED_FLAGS)

    def test_set_default_n_workers(self):
        """``set_default_n_workers`` records an ``n_workers`` parser default."""
        helper = self._make_helper()
        max_mem_gb_per_worker = 10.0
        helper.set_default_n_workers(max_mem_gb_per_worker)
        assert "n_workers" in helper.parser._defaults

    def test_parse_client_args(self):
        """``parse_client_args`` maps ``device`` to ``cluster_type`` and drops extras."""
        helper = self._make_helper()
        namespace = argparse.Namespace(device="gpu", n_workers=10, random_arg="abc")
        client_args = helper.parse_client_args(namespace)
        assert client_args["cluster_type"] == "gpu"
        assert client_args["n_workers"] == 10
        assert "random_arg" not in client_args

    def test_parse_distributed_classifier_args(self):
        """The returned parser has cluster, I/O and model options plus defaults."""
        helper = self._make_helper()
        # Assertions target the parser this call returns.
        parser = helper.parse_distributed_classifier_args()

        self._assert_in_help(parser, *self._CLASSIFIER_CLUSTER_FLAGS)
        assert "rmm_pool_size" in parser._defaults
        assert "set_torch_to_use_rmm" in parser._defaults

        self._assert_in_help(
            parser,
            "--input-data-dir",
            "--output-data-dir",
            "--input-file-type",
            "--input-file-extension",
            "--output-file-type",
            "--input-text-field",
            "--batch-size",
            "--pretrained-model-name-or-path",
            "--autocast",
            "--max-chars",
        )

    def test_add_distributed_classifier_cluster_args(self):
        """The helper's own parser gains the classifier cluster flags and defaults."""
        helper = self._make_helper()
        helper.add_distributed_classifier_cluster_args()

        self._assert_in_help(helper.parser, *self._CLASSIFIER_CLUSTER_FLAGS)
        assert "rmm_pool_size" in helper.parser._defaults
        assert "set_torch_to_use_rmm" in helper.parser._defaults

    def test_parse_gpu_dedup_args(self):
        """The helper's own parser gains the dedup options and defaults."""
        helper = self._make_helper()
        # The return value is ignored; assertions target ``helper.parser``.
        helper.parse_gpu_dedup_args()

        self._assert_in_help(helper.parser, *self._DISTRIBUTED_FLAGS)

        assert "device" in helper.parser._defaults
        assert "set_torch_to_use_rmm" in helper.parser._defaults

        self._assert_in_help(
            helper.parser,
            "--input-data-dirs",
            "--input-json-text-field",
            "--input-json-id-field",
            "--log-dir",
            "--profile-path",
        )

    def test_parse_semdedup_args(self):
        """The returned parser has distributed, input and config options plus defaults."""
        helper = self._make_helper()
        # Assertions target the parser this call returns.
        parser = helper.parse_semdedup_args()

        self._assert_in_help(parser, *self._DISTRIBUTED_FLAGS)

        self._assert_in_help(
            parser,
            "--input-data-dir",
            "--input-file-extension",
            "--input-file-type",
            "--input-text-field",
            "--id-column",
            "--id-column-type",
        )

        self._assert_in_help(parser, "--config-file")

        assert "rmm_pool_size" in parser._defaults
        assert "device" in parser._defaults
        assert "set_torch_to_use_rmm" in parser._defaults

0 commit comments

Comments
 (0)