# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
-import os
-import tempfile

-from datasets import DatasetDict, load_dataset
import duckdb
from huggingface_hub import HfApi, HfFileSystem
-import pandas as pd

-from data_designer.config.utils.io_helpers import validate_dataset_file_path
from data_designer.logging import quiet_noisy_logger

quiet_noisy_logger("httpx")
@@ -31,9 +26,6 @@ def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
    @abstractmethod
    def get_dataset_uri(self, file_id: str) -> str: ...

-    @abstractmethod
-    def load_dataset(self, file_id: str) -> pd.DataFrame: ...
-

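With load_dataset gone from the interface, the two surviving abstract methods are
what callers compose: open a connection, resolve a URI, and let DuckDB do the
read. A minimal usage sketch, assuming any concrete SeedDatasetDataStore
(datastore, the file name, and the query below are illustrative):

    con = datastore.create_duckdb_connection()
    uri = datastore.get_dataset_uri("seed.parquet")
    preview = con.sql(f"SELECT * FROM '{uri}' LIMIT 5").df()
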
class LocalSeedDatasetDataStore(SeedDatasetDataStore):
    """Local filesystem-based dataset storage."""
@@ -44,20 +36,6 @@ def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
    def get_dataset_uri(self, file_id: str) -> str:
        return file_id

-    def load_dataset(self, file_id: str) -> pd.DataFrame:
-        filepath = validate_dataset_file_path(file_id)
-        match filepath.suffix.lower():
-            case ".csv":
-                return pd.read_csv(filepath)
-            case ".parquet":
-                return pd.read_parquet(filepath)
-            case ".json":
-                return pd.read_json(filepath, lines=True)
-            case ".jsonl":
-                return pd.read_json(filepath, lines=True)
-            case _:
-                raise ValueError("Local datasets must be CSV, Parquet, JSON, or JSONL")
-
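The deleted loader above dispatched on file suffix through pandas. A minimal
sketch of the same dispatch done through DuckDB table functions instead, in line
with the connection-plus-URI interface this class now exposes (the helper name
and reader mapping are illustrative, not part of this module):

    import duckdb
    import pandas as pd

    _READERS = {
        ".csv": "read_csv_auto",
        ".parquet": "read_parquet",
        ".json": "read_json_auto",
        ".jsonl": "read_json_auto",
    }

    def load_via_duckdb(con: duckdb.DuckDBPyConnection, uri: str, suffix: str) -> pd.DataFrame:
        reader = _READERS.get(suffix)
        if reader is None:
            raise ValueError("Local datasets must be CSV, Parquet, JSON, or JSONL")
        # The table function reads the file directly; .df() materializes a DataFrame.
        return con.sql(f"SELECT * FROM {reader}('{uri}')").df()
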

class HfHubSeedDatasetDataStore(SeedDatasetDataStore):
    """Hugging Face and Data Store dataset storage."""
@@ -76,55 +54,6 @@ def get_dataset_uri(self, file_id: str) -> str:
        repo_id, filename = self._get_repo_id_and_filename(identifier)
        return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}"

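Assuming _HF_DATASETS_PREFIX expands to an hf://datasets/ style prefix, the URI
built here can be handed straight to DuckDB, whose httpfs extension understands
hf:// paths (the repo and file below are made up):

    import duckdb

    con = duckdb.connect()
    con.sql("SELECT COUNT(*) FROM 'hf://datasets/acme/seed-data/train.parquet'")
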
-    def load_dataset(self, file_id: str) -> pd.DataFrame:
-        identifier = file_id.removeprefix(_HF_DATASETS_PREFIX)
-        repo_id, filename = self._get_repo_id_and_filename(identifier)
-        is_file = "." in file_id.split("/")[-1]
-
-        self._validate_repo(repo_id)
-
-        if is_file:
-            self._validate_file(repo_id, filename)
-            return self._download_and_load_file(repo_id, filename)
-        else:
-            return self._download_and_load_directory(repo_id, filename)
-
-    def _validate_repo(self, repo_id: str) -> None:
-        """Validate that the repository exists and is a dataset repo."""
-        if not self.hfapi.repo_exists(repo_id, repo_type="dataset"):
-            if self.hfapi.repo_exists(repo_id, repo_type="model"):
-                raise FileNotFoundError(f"Repo {repo_id} is a model repo, not a dataset repo")
-            raise FileNotFoundError(f"Repo {repo_id} does not exist")
-
-    def _validate_file(self, repo_id: str, filename: str) -> None:
-        """Validate that the file exists in the repository."""
-        if not self.hfapi.file_exists(repo_id, filename, repo_type="dataset"):
-            raise FileNotFoundError(f"File {filename} does not exist in repo {repo_id}")
-
-    def _download_and_load_file(self, repo_id: str, filename: str) -> pd.DataFrame:
-        """Download a specific file and load it as a dataset."""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            self.hfapi.hf_hub_download(
-                repo_id=repo_id,
-                filename=filename,
-                local_dir=temp_dir,
-                repo_type="dataset",
-            )
-            return self._load_local_dataset(temp_dir)
-
-    def _download_and_load_directory(self, repo_id: str, directory: str) -> pd.DataFrame:
-        """Download entire repo and load from specific subdirectory."""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            self.hfapi.snapshot_download(
-                repo_id=repo_id,
-                local_dir=temp_dir,
-                repo_type="dataset",
-            )
-            dataset_path = os.path.join(temp_dir, directory)
-            if not os.path.exists(dataset_path):
-                dataset_path = temp_dir
-            return self._load_local_dataset(dataset_path)
-
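Worth noting from the deleted method: the file-versus-directory decision was a
dot-in-last-segment heuristic, not a repo listing. Standalone, with made-up
identifiers:

    "." in "acme/seed-data/train/part-0.parquet".split("/")[-1]  # True  -> file
    "." in "acme/seed-data/train".split("/")[-1]                 # False -> directory
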
    def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
        """Extract repo_id and filename from identifier."""
        parts = identifier.split("/", 2)
@@ -135,10 +64,3 @@ def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
            )
        repo_ns, repo_name, filename = parts
        return f"{repo_ns}/{repo_name}", filename
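The maxsplit=2 in _get_repo_id_and_filename keeps any nested path intact in the
filename part, so identifiers deeper than repo_ns/repo_name/file still resolve
(the identifier below is made up):

    "acme/seed-data/train/part-0.parquet".split("/", 2)
    # -> ['acme', 'seed-data', 'train/part-0.parquet']
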
-
-    def _load_local_dataset(self, path: str) -> pd.DataFrame:
-        """Load dataset from local path."""
-        hf_dataset = load_dataset(path=path)
-        if isinstance(hf_dataset, DatasetDict):
-            hf_dataset = hf_dataset[list(hf_dataset.keys())[0]]
-        return hf_dataset.to_pandas()
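For reference, the deleted helper took the first split whenever datasets handed
back a DatasetDict, which is insertion-order dependent. A self-contained sketch
of that behavior (the toy splits are made up):

    from datasets import Dataset, DatasetDict

    dd = DatasetDict({
        "train": Dataset.from_dict({"x": [1, 2]}),
        "test": Dataset.from_dict({"x": [3]}),
    })
    first = dd[list(dd.keys())[0]]  # "train" here; whichever key comes first wins
    df = first.to_pandas()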