|
| 1 | +""" |
| 2 | +Hugging Face Hub file source plugin using fsspec. |
| 3 | +""" |
| 4 | + |
| 5 | +import logging |
| 6 | +from typing import ( |
| 7 | + Annotated, |
| 8 | + Literal, |
| 9 | + Optional, |
| 10 | + Union, |
| 11 | +) |
| 12 | + |
| 13 | +from fsspec import AbstractFileSystem |
| 14 | +from pydantic import Field |
| 15 | + |
| 16 | +from galaxy.files.models import ( |
| 17 | + AnyRemoteEntry, |
| 18 | + FilesSourceRuntimeContext, |
| 19 | + RemoteDirectory, |
| 20 | +) |
| 21 | + |
| 22 | +try: |
| 23 | + from huggingface_hub import ( |
| 24 | + HfApi, |
| 25 | + HfFileSystem, |
| 26 | + ) |
| 27 | +except ImportError: |
| 28 | + HfApi = None |
| 29 | + HfFileSystem = None |
| 30 | + |
| 31 | +from galaxy.exceptions import MessageException |
| 32 | +from galaxy.files.sources._fsspec import ( |
| 33 | + CacheOptionsDictType, |
| 34 | + FsspecBaseFileSourceConfiguration, |
| 35 | + FsspecBaseFileSourceTemplateConfiguration, |
| 36 | + FsspecFilesSource, |
| 37 | +) |
| 38 | +from galaxy.util.config_templates import TemplateExpansion |
| 39 | + |
| 40 | +log = logging.getLogger(__name__) |
| 41 | + |
| 42 | +SortByOptions = Literal["last_modified", "trending_score", "created_at", "downloads", "likes"] |
| 43 | + |
| 44 | +DEFAULT_SORT_BY: SortByOptions = "downloads" |
| 45 | + |
| 46 | +MAX_REPO_LIMIT = 1000 |
| 47 | + |
| 48 | + |
| 49 | +class HuggingFaceFileSourceTemplateConfiguration(FsspecBaseFileSourceTemplateConfiguration): |
| 50 | + token: Annotated[ |
| 51 | + Union[str, TemplateExpansion, None], |
| 52 | + Field( |
| 53 | + description="Hugging Face API token for accessing private model repositories. " |
| 54 | + "If not provided, only public repositories will be accessible.", |
| 55 | + ), |
| 56 | + ] = None |
| 57 | + endpoint: Annotated[ |
| 58 | + Union[str, TemplateExpansion, None], |
| 59 | + Field( |
| 60 | + description="Custom endpoint for Hugging Face Hub. " |
| 61 | + "If not provided, the default Hugging Face Hub will be used (https://huggingface.co).", |
| 62 | + ), |
| 63 | + ] = None |
| 64 | + |
| 65 | + |
| 66 | +class HuggingFaceFileSourceConfiguration(FsspecBaseFileSourceConfiguration): |
| 67 | + token: Optional[str] = None |
| 68 | + endpoint: Optional[str] = None |
| 69 | + |
| 70 | + |
| 71 | +class HuggingFaceFilesSource( |
| 72 | + FsspecFilesSource[HuggingFaceFileSourceTemplateConfiguration, HuggingFaceFileSourceConfiguration] |
| 73 | +): |
| 74 | + plugin_type = "huggingface" |
| 75 | + required_module = HfFileSystem |
| 76 | + required_package = "huggingface_hub" |
| 77 | + |
| 78 | + template_config_class = HuggingFaceFileSourceTemplateConfiguration |
| 79 | + resolved_config_class = HuggingFaceFileSourceConfiguration |
| 80 | + |
| 81 | + def _open_fs( |
| 82 | + self, |
| 83 | + context: FilesSourceRuntimeContext[HuggingFaceFileSourceConfiguration], |
| 84 | + cache_options: CacheOptionsDictType, |
| 85 | + ) -> AbstractFileSystem: |
| 86 | + if HfFileSystem is None: |
| 87 | + raise self.required_package_exception |
| 88 | + |
| 89 | + config = context.config |
| 90 | + return HfFileSystem( |
| 91 | + token=config.token or False, # Use False to disable authentication |
| 92 | + endpoint=config.endpoint, |
| 93 | + **cache_options, |
| 94 | + ) |
| 95 | + |
| 96 | + def _to_filesystem_path(self, path: str) -> str: |
| 97 | + """Transform entry path to Hugging Face filesystem path.""" |
| 98 | + if path == "/": |
| 99 | + # Hugging Face does not implement access to the repositories root |
| 100 | + return "" |
| 101 | + # Remove leading slash for HF compatibility |
| 102 | + return path.lstrip("/") |
| 103 | + |
| 104 | + def _extract_timestamp(self, info: dict) -> Optional[str]: |
| 105 | + """Extract timestamp from Hugging Face file info to use it in the RemoteFile entry.""" |
| 106 | + last_commit: dict = info.get("last_commit", {}) |
| 107 | + return last_commit.get("date") |
| 108 | + |
| 109 | + def _list( |
| 110 | + self, |
| 111 | + context: FilesSourceRuntimeContext[HuggingFaceFileSourceConfiguration], |
| 112 | + path="/", |
| 113 | + recursive=False, |
| 114 | + write_intent: bool = False, |
| 115 | + limit: Optional[int] = None, |
| 116 | + offset: Optional[int] = None, |
| 117 | + query: Optional[str] = None, |
| 118 | + sort_by: Optional[str] = None, |
| 119 | + ) -> tuple[list[AnyRemoteEntry], int]: |
| 120 | + # If we're at the root, list repositories using HfApi |
| 121 | + if path == "/": |
| 122 | + return self._list_repositories(config=context.config, limit=limit, offset=offset, query=query) |
| 123 | + |
| 124 | + # For non-root paths, use the parent implementation |
| 125 | + return super()._list( |
| 126 | + context=context, |
| 127 | + path=path, |
| 128 | + recursive=recursive, |
| 129 | + limit=limit, |
| 130 | + offset=offset, |
| 131 | + query=query, |
| 132 | + sort_by=sort_by, |
| 133 | + ) |
| 134 | + |
| 135 | + def _list_repositories( |
| 136 | + self, |
| 137 | + config: HuggingFaceFileSourceConfiguration, |
| 138 | + limit: Optional[int] = None, |
| 139 | + offset: Optional[int] = None, |
| 140 | + query: Optional[str] = None, |
| 141 | + ) -> tuple[list[AnyRemoteEntry], int]: |
| 142 | + if HfApi is None: |
| 143 | + raise self.required_package_exception |
| 144 | + |
| 145 | + api = HfApi( |
| 146 | + token=config.token or False, # Use False to disable authentication |
| 147 | + endpoint=config.endpoint, |
| 148 | + ) |
| 149 | + try: |
| 150 | + repos_iter = api.list_models(search=query, sort=DEFAULT_SORT_BY, direction=-1, limit=MAX_REPO_LIMIT) |
| 151 | + |
| 152 | + # Convert repositories to directory entries |
| 153 | + entries_list: list[AnyRemoteEntry] = [] |
| 154 | + for repo in repos_iter: |
| 155 | + repo_id = repo.id if hasattr(repo, "id") else str(repo) |
| 156 | + entry = RemoteDirectory( |
| 157 | + name=repo_id, |
| 158 | + uri=self.uri_from_path(repo_id), |
| 159 | + path=repo_id, |
| 160 | + ) |
| 161 | + entries_list.append(entry) |
| 162 | + |
| 163 | + total_count = len(entries_list) |
| 164 | + |
| 165 | + # Apply pagination |
| 166 | + if offset is not None and limit is not None: |
| 167 | + entries_list = entries_list[offset : offset + limit] |
| 168 | + elif limit is not None: |
| 169 | + entries_list = entries_list[:limit] |
| 170 | + |
| 171 | + return entries_list, total_count |
| 172 | + |
| 173 | + except Exception as e: |
| 174 | + raise MessageException(f"Failed to list repositories from Hugging Face Hub: {e}") from e |
| 175 | + |
| 176 | + |
| 177 | +__all__ = ["HuggingFaceFilesSource"] |
0 commit comments