|
17 | 17 |
|
18 | 18 | import itertools |
19 | 19 | import os |
| 20 | +import posixpath |
20 | 21 | import re |
21 | 22 |
|
22 | 23 |
|
@@ -46,33 +47,33 @@ def snakecase_to_camelcase(name): |
46 | 47 |
|
47 | 48 |
|
def filename_prefix_for_name(name):
    """Return the filename prefix for a dataset name.

    The prefix is whatever ``camelcase_to_snakecase`` produces for ``name``.

    Raises:
        ValueError: If ``name`` is not its own POSIX basename, i.e. it
            carries a directory component and looks like a path.
    """
    # posixpath (not os.path) so "/" is the only separator, on every platform.
    is_bare_name = posixpath.basename(name) == name
    if is_bare_name:
        return camelcase_to_snakecase(name)
    raise ValueError(f"Should be a dataset name, not a path: {name}")
52 | 53 |
|
53 | 54 |
|
def filename_prefix_for_split(name, split):
    """Return the filename prefix for one split of a dataset.

    The result is ``"<prefix>-<split>"`` where ``<prefix>`` comes from
    ``filename_prefix_for_name(name)``.

    Args:
        name: Dataset name; must not contain a POSIX directory component.
        split: Split name; must match the module-level ``_split_re`` pattern.

    Raises:
        ValueError: If ``name`` looks like a path, or if ``split`` does not
            match ``_split_re``.
    """
    # Validate the name first so a bad name is reported before a bad split.
    if posixpath.basename(name) != name:
        raise ValueError(f"Should be a dataset name, not a path: {name}")
    if not re.match(_split_re, split):
        # The pattern is quoted with exactly one quote on each side.
        raise ValueError(f"Split name should match '{_split_re}' but got '{split}'.")
    return f"{filename_prefix_for_name(name)}-{split}"
60 | 61 |
|
61 | 62 |
|
def filepattern_for_dataset_split(path, dataset_name, split, filetype_suffix=None):
    """Build the glob pattern for the files of one dataset split.

    Args:
        path: Directory part, joined with POSIX ("/") separators.
        dataset_name: Dataset name, validated by ``filename_prefix_for_split``.
        split: Split name, validated by ``filename_prefix_for_split``.
        filetype_suffix: Optional extension without the leading dot.

    Returns:
        ``"<path>/<prefix>*"``, or ``"<path>/<prefix>*.<filetype_suffix>"``
        when ``filetype_suffix`` is truthy.
    """
    split_prefix = filename_prefix_for_split(dataset_name, split)
    # The wildcard goes before the suffix, so anything between the prefix
    # and the extension is matched.
    pattern = posixpath.join(path, split_prefix) + "*"
    return f"{pattern}.{filetype_suffix}" if filetype_suffix else pattern
68 | 70 |
|
69 | 71 |
|
70 | | -def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None): |
| 72 | +def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, num_shards=1): |
71 | 73 | prefix = filename_prefix_for_split(dataset_name, split) |
72 | | - prefix = os.path.join(path, prefix) |
| 74 | + prefix = posixpath.join(path, prefix) |
73 | 75 |
|
74 | | - if shard_lengths: |
75 | | - num_shards = len(shard_lengths) |
| 76 | + if num_shards > 1: |
76 | 77 | filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)] |
77 | 78 | if filetype_suffix: |
78 | 79 | filenames = [filename + f".{filetype_suffix}" for filename in filenames] |
|
0 commit comments