@@ -31,47 +31,74 @@ class DatastoreSettings(BaseModel):
     token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")


-def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
-    """Extract column names based on file type. Supports glob patterns like '../path/*.parquet'."""
-    file_path = Path(file_path)
-    if "*" in str(file_path):
-        matching_files = sorted(file_path.parent.glob(file_path.name))
-        if not matching_files:
-            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
-        logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
-        file_path = matching_files[0]
+def get_file_column_names(file_reference: Union[str, Path, HfFileSystem], file_type: str) -> list[str]:
+    """Get column names from a dataset file.
+
+    Args:
+        file_reference: Path to the dataset file, or an HfFileSystem object.
+        file_type: Type of the dataset file. Must be one of: 'parquet', 'json', 'jsonl', 'csv'.

+    Raises:
+        InvalidFilePathError: If the file type is not supported.
+
+    Returns:
+        List of column names.
+    """
     if file_type == "parquet":
         try:
-            schema = pq.read_schema(file_path)
+            schema = pq.read_schema(file_reference)
             if hasattr(schema, "names"):
                 return schema.names
             else:
                 return [field.name for field in schema]
         except Exception as e:
-            logger.warning(f"Failed to process parquet file {file_path}: {e}")
+            logger.warning(f"Failed to process parquet file {file_reference}: {e}")
             return []
     elif file_type in ["json", "jsonl"]:
-        return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist()
+        return pd.read_json(file_reference, orient="records", lines=True, nrows=1).columns.tolist()
     elif file_type == "csv":
         try:
-            df = pd.read_csv(file_path, nrows=1)
+            df = pd.read_csv(file_reference, nrows=1)
             return df.columns.tolist()
         except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
-            logger.warning(f"Failed to process CSV file {file_path}: {e}")
+            logger.warning(f"Failed to process CSV file {file_reference}: {e}")
             return []
     else:
         raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}")


 def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]:
     if hasattr(seed_dataset_reference, "datastore_settings"):
-        return _fetch_seed_dataset_column_names_from_datastore(
+        return fetch_seed_dataset_column_names_from_datastore(
             seed_dataset_reference.repo_id,
             seed_dataset_reference.filename,
             seed_dataset_reference.datastore_settings,
         )
-    return _fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
+
+
+def fetch_seed_dataset_column_names_from_datastore(
+    repo_id: str,
+    filename: str,
+    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
+) -> list[str]:
+    file_type = filename.split(".")[-1]
+    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
+        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
+
+    datastore_settings = resolve_datastore_settings(datastore_settings)
+    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
+
+    file_path = _extract_single_file_path_from_glob_pattern_if_present(f"datasets/{repo_id}/{filename}", fs=fs)
+
+    with fs.open(file_path) as f:
+        return get_file_column_names(f, file_type)
+
+
+def fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
+    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+    dataset_path = _extract_single_file_path_from_glob_pattern_if_present(dataset_path)
+    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])


 def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings:
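For context, here is how the two fetchers made public by this change would be called. This is a minimal sketch, assuming the functions are imported from the module being edited (its path is not shown in the diff); the repo id, filename, and local glob below are hypothetical placeholders:

```python
from pathlib import Path

# Assumed import; the actual module path is not visible in this diff.
# from ... import (
#     fetch_seed_dataset_column_names_from_datastore,
#     fetch_seed_dataset_column_names_from_local_file,
# )

# Local file or glob pattern: the glob is resolved to a single file first.
local_columns = fetch_seed_dataset_column_names_from_local_file(Path("seed_data/*.parquet"))

# Datastore variant: repo_id and filename are placeholders; passing
# datastore_settings=None falls back to resolve_datastore_settings defaults.
remote_columns = fetch_seed_dataset_column_names_from_datastore(
    repo_id="my-org/my-seed-dataset",
    filename="train.parquet",
    datastore_settings=None,
)
```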
@@ -114,25 +141,35 @@ def upload_to_hf_hub(
     return f"{repo_id}/{filename}"


-def _fetch_seed_dataset_column_names_from_datastore(
-    repo_id: str,
-    filename: str,
-    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
-) -> list[str]:
-    file_type = filename.split(".")[-1]
-    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
-        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
-
-    datastore_settings = resolve_datastore_settings(datastore_settings)
-    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
-
-    with fs.open(f"datasets/{repo_id}/{filename}") as f:
-        return get_file_column_names(f, file_type)
-
+def _extract_single_file_path_from_glob_pattern_if_present(
+    file_path: str | Path,
+    fs: HfFileSystem | None = None,
+) -> Path:
+    file_path = Path(file_path)

-def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
-    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
-    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
+    # no glob pattern
+    if "*" not in str(file_path):
+        return file_path
+
+    # glob pattern with HfFileSystem
+    if fs is not None:
+        file_to_check = None
+        file_extension = file_path.name.split(".")[-1]
+        for file in fs.ls(str(file_path.parent)):
+            filename = file["name"]
+            if filename.endswith(f".{file_extension}"):
+                file_to_check = filename
+                break  # keep the first matching file, as the log message below states
+        if file_to_check is None:
+            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+        logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+        return Path(file_to_check)
+
+    # glob pattern with local file system
+    if not (matching_files := sorted(file_path.parent.glob(file_path.name))):
+        raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+    logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+    return matching_files[0]


 def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
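The helper's local-filesystem branch can be exercised in isolation. A small sketch under assumed inputs (the temporary directory and file names below are hypothetical, and the function is assumed importable from the same module):

```python
import tempfile
from pathlib import Path

# Create two files so the glob pattern has something to match.
tmp = Path(tempfile.mkdtemp())
(tmp / "part-0.parquet").touch()
(tmp / "part-1.parquet").touch()

# A concrete path (no '*') passes through untouched.
assert _extract_single_file_path_from_glob_pattern_if_present(tmp / "part-0.parquet") == tmp / "part-0.parquet"

# A glob pattern resolves to the first sorted match.
assert _extract_single_file_path_from_glob_pattern_if_present(tmp / "*.parquet").name == "part-0.parquet"
```

When `fs` is given, the same call lists `file_path.parent` on the Hugging Face datastore instead of globbing locally, which is why `fetch_seed_dataset_column_names_from_datastore` passes its `HfFileSystem` instance through.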