Skip to content

Commit cf2d23a

Browse files
authored
Try to read the data path arguments directly from a file (#254)
1 parent 99c8fe0 commit cf2d23a

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

megatron/arguments.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,26 @@ def __call__(self, parser, args, values, option_string=None):
841841
'test will be run on each of those groups independently',
842842
action=parse_data_paths)
843843

844+
class parse_data_paths_path(argparse.Action):
    """Argparse action that loads weighted-split data paths from a file.

    The option value is the path of a text file holding exactly one line of
    double-quoted entries separated by single spaces, e.g. ``"a" "b" "c"``.
    The entries are stored on the namespace under this action's ``dest`` with
    its trailing ``_path`` removed, and are then handed to the
    ``parse_data_paths`` action as if they had been supplied to the option
    whose name is this option's name without the trailing ``-path``.
    """

    def __call__(self, parser, args, values, option_string=None):
        """Read the quoted entries from the file ``values`` and delegate.

        Args:
            parser: The parser invoking this action; forwarded unchanged.
            args: Namespace that receives the parsed entries.
            values: Path of the file to read.
            option_string: Option that triggered the action; must be one of
                the three ``--*-weighted-split-paths-path`` options.

        Raises:
            argparse.ArgumentError: If the file does not contain exactly one
                line, or that line is not wrapped in double quotes.
        """
        expected_option_strings = [
            "--train-weighted-split-paths-path",
            "--valid-weighted-split-paths-path",
            "--test-weighted-split-paths-path",
        ]
        # Programming invariant (how the action was registered), not user
        # input, so an assertion is appropriate here.
        assert option_string in expected_option_strings, f"Expected {option_string} to be in {expected_option_strings}"

        # Only hold the file open for the read itself.
        with open(values, "r") as fi:
            lines = fi.readlines()

        # User-supplied content: report problems through argparse instead of
        # ``assert``, which disappears under ``python -O``.
        if len(lines) != 1:
            raise argparse.ArgumentError(
                self,
                f"Expected exactly 1 line in {values}, got {len(lines)}",
            )

        # The final newline is optional; editors do not always write one.
        line = lines[0].rstrip("\n")
        # Require two distinct quote characters so a lone '"' is rejected.
        if len(line) < 2 or not (line.startswith("\"") and line.endswith("\"")):
            raise argparse.ArgumentError(
                self, f"Invalid input format, got {lines}"
            )

        # Drop the outer quotes, then split on the `" "` separators.
        entries = line[1:-1].split("\" \"")

        weighted_split_paths_dest = re.sub(r"_path$", "", self.dest)
        weighted_split_paths_option = re.sub(r"-path$", "", self.option_strings[0])

        setattr(args, weighted_split_paths_dest, entries)
        # Reuse the regular action so both ways of passing the paths are
        # processed by the same code.
        delegate = parse_data_paths(
            option_strings=[weighted_split_paths_option],
            dest=weighted_split_paths_dest,
        )
        delegate(parser, args, entries, option_string=weighted_split_paths_option)
858+
859+
860+
# File-based variants of the weighted-split options: each takes the path of a
# file whose single line holds the double-quoted, space-separated entries,
# which parse_data_paths_path reads and forwards to parse_data_paths.
group.add_argument('--train-weighted-split-paths-path', type=str,
                   action=parse_data_paths_path, default=None,
                   help='Path to a file containing, on a single line, the '
                   'double-quoted space-separated values to process as '
                   '--train-weighted-split-paths.')
group.add_argument('--valid-weighted-split-paths-path', type=str,
                   action=parse_data_paths_path, default=None,
                   help='Path to a file containing, on a single line, the '
                   'double-quoted space-separated values to process as '
                   '--valid-weighted-split-paths.')
group.add_argument('--test-weighted-split-paths-path', type=str,
                   action=parse_data_paths_path, default=None,
                   help='Path to a file containing, on a single line, the '
                   'double-quoted space-separated values to process as '
                   '--test-weighted-split-paths.')
863+
844864
group.add_argument('--log-path', type=str, default=None,
845865
help='Path to the save arguments file.')
846866
group.add_argument('--vocab-file', type=str, default=None,

0 commit comments

Comments
 (0)