|
19 | 19 | from ..utils.system import format_model_size, get_system_resources |
20 | 20 | from ..utils.progress import configure_hf_hub_progress |
21 | 21 | from ..utils.model_cache import model_cache_manager |
| 22 | +from ..utils.huggingface_search import hf_searcher |
22 | 23 | from ..logger.logger import logger |
23 | 24 |
|
24 | 25 | # Initialize rich console |
@@ -203,89 +204,216 @@ def remove(model_id: str, force: bool): |
203 | 204 | sys.exit(1) |
204 | 205 |
|
def _format_download_count(downloads: int) -> str:
    """Render a raw download count as a compact human string (e.g. "1.2M", "3.4K").

    Returns an empty string for zero/negative counts so registry models
    (which carry no hub statistics) show a blank Downloads cell.
    """
    if downloads <= 0:
        return ""
    if downloads >= 1000000:
        return f"{downloads/1000000:.1f}M"
    if downloads >= 1000:
        return f"{downloads/1000:.1f}K"
    return str(downloads)


@models.command()
@click.option('--search', help='Search models by keywords, tags, or description')
@click.option('--limit', type=int, default=20, help='Maximum number of models to show (default: 20)')
@click.option('--format', 'output_format', type=click.Choice(['table', 'json']), default='table',
              help='Output format (table or json)')
@click.option('--registry-only', is_flag=True, help='Show only LocalLab registry models')
@click.option('--hub-only', is_flag=True, help='Show only HuggingFace Hub models')
@click.option('--sort', type=click.Choice(['downloads', 'likes', 'recent']), default='downloads',
              help='Sort HuggingFace models by downloads, likes, or recent updates')
@click.option('--tags', help='Filter by tags (comma-separated, e.g., "conversational,chat")')
def discover(search: Optional[str], limit: int, output_format: str, registry_only: bool,
             hub_only: bool, sort: str, tags: Optional[str]):
    """Discover available models from HuggingFace Hub and LocalLab registry"""
    try:
        console.print("🔍 Discovering available models...", style="blue")

        all_models = []

        # Registry models are skipped only when the user asked for hub results alone.
        if not hub_only:
            registry_models = _get_registry_models()
            all_models.extend(registry_models)
            console.print(f"📚 Found {len(registry_models)} LocalLab registry models", style="dim")

        # Hub models are skipped only when the user asked for registry results alone.
        if not registry_only:
            hf_models, hf_success = _get_huggingface_models(search, limit, sort, tags)
            if hf_success:
                all_models.extend(hf_models)
                console.print(f"🤗 Found {len(hf_models)} HuggingFace Hub models", style="dim")
            else:
                # Degrade gracefully: hub search is best-effort (network or
                # optional-dependency failures must not kill the command).
                console.print("⚠️ Could not search HuggingFace Hub (network issue or missing dependencies)", style="yellow")
                if search or tags:
                    console.print("💡 Try using --registry-only to search LocalLab registry models only.", style="dim")

        # The hub query already honoured --search server-side; registry
        # entries are filtered locally against name and description.
        if search and not hub_only:
            search_lower = search.lower()
            registry_filtered = [
                m for m in all_models
                if m.get("type") == "Registry" and (
                    search_lower in m["name"].lower() or
                    search_lower in m["description"].lower()
                )
            ]
            # Keep HF models untouched alongside the filtered registry models.
            hf_models_in_list = [m for m in all_models if m.get("type") == "HuggingFace"]
            all_models = registry_filtered + hf_models_in_list

        # Mark models that are already present in the local cache.
        cached_models = model_cache_manager.get_cached_models()
        cached_ids = {m["id"] for m in cached_models}

        for model in all_models:
            model["is_cached"] = model["id"] in cached_ids

        # Order by the requested metric, then pin registry models to the top.
        # Two stable passes (sort by secondary key first, primary key last)
        # are equivalent to one tuple-key sort, but let each metric pick its
        # own direction.
        if sort == "likes":
            all_models.sort(key=lambda m: m.get("likes", 0), reverse=True)
        elif sort == "recent":
            # BUGFIX: this previously sorted the timestamp strings ascending,
            # which listed the OLDEST models first; "recent" must surface the
            # newest updates first.
            all_models.sort(key=lambda m: m.get("updated_at", ""), reverse=True)
        else:  # downloads (default)
            all_models.sort(key=lambda m: m.get("downloads", 0), reverse=True)
        all_models.sort(key=lambda m: 0 if m.get("type") == "Registry" else 1)

        # Apply final limit
        all_models = all_models[:limit]

        if output_format == 'json':
            click.echo(json.dumps(all_models, indent=2))
            return

        if not all_models:
            console.print("📭 No models found matching your criteria.", style="yellow")
            if not registry_only:
                console.print("💡 Try adjusting your search terms or check your internet connection.", style="dim")
            return

        # Create table
        table = Table(title="🌟 Available Models")
        table.add_column("Model ID", style="cyan", no_wrap=True, max_width=30)
        table.add_column("Name", style="green", max_width=20)
        table.add_column("Size", style="magenta", justify="right")
        table.add_column("Type", style="blue")
        table.add_column("Downloads", style="yellow", justify="right")
        table.add_column("Status", style="bright_green")
        table.add_column("Description", style="dim", max_width=40)

        for model in all_models:
            status = "✅ Cached" if model["is_cached"] else "📥 Available"
            downloads_str = _format_download_count(model.get("downloads", 0))

            table.add_row(
                model["id"],
                model["name"],
                model.get("size", "Unknown"),
                model.get("type", "Unknown"),
                downloads_str,
                status,
                model["description"][:50] + "..." if len(model["description"]) > 50 else model["description"]
            )

        console.print(table)

        # Show summary
        cached_count = sum(1 for m in all_models if m["is_cached"])
        registry_count = sum(1 for m in all_models if m.get("type") == "Registry")
        hf_count = len(all_models) - registry_count

        console.print(f"\n📊 Found {len(all_models)} models:")
        console.print(f"   • {registry_count} LocalLab registry models")
        console.print(f"   • {hf_count} HuggingFace Hub models")
        console.print(f"   • {cached_count} already cached locally")
        console.print("\n💡 Use 'locallab models download <model_id>' to download a model locally.")

        if not registry_only and hf_count > 0:
            console.print("🔍 Use --search to find specific models or --tags to filter by categories.")

    except Exception as e:
        logger.error(f"Error discovering models: {e}")
        console.print(f"❌ Error discovering models: {str(e)}", style="red")
        sys.exit(1)
288 | 335 |
|
def _get_registry_models():
    """Build discover-format entries for every model in the LocalLab registry.

    Each entry mirrors the dict shape produced for HuggingFace Hub results so
    both sources can be merged, filtered, and sorted uniformly by the caller.
    """
    return [
        {
            "id": registry_id,
            "name": entry.get("name", registry_id),
            "description": entry.get("description", "LocalLab registry model"),
            "size": entry.get("size", "Unknown"),
            "type": "Registry",
            "downloads": 0,  # registry entries carry no hub download stats
            "likes": 0,
            "requirements": entry.get("requirements", {}),
            "is_cached": False,
            "tags": [],
            "author": "LocalLab",
            "updated_at": "",
        }
        for registry_id, entry in MODEL_REGISTRY.items()
    ]
| 358 | + |
def _get_huggingface_models(search: Optional[str], limit: int, sort: str, tags: Optional[str]):
    """Query the HuggingFace Hub and normalize results to discover-format dicts.

    Returns a ``(models, success)`` pair; any failure yields ``([], False)``
    so the caller can degrade gracefully instead of crashing the command.
    """
    try:
        # Comma-separated tag string -> cleaned list (may be empty).
        tag_list = [t.strip() for t in tags.split(',') if t.strip()] if tags else []

        # Map CLI sort choices onto the Hub API's field names.
        hf_sort = {"likes": "likes", "recent": "lastModified"}.get(sort, "downloads")

        # A keyword search takes precedence; tags are only applied when no
        # search query was given (same precedence as the original branches).
        call_kwargs = {"search_query": None, "limit": limit, "sort": hf_sort}
        if search:
            call_kwargs["search_query"] = search
        elif tag_list:
            call_kwargs["filter_tags"] = tag_list

        hf_models, success = hf_searcher.search_models(**call_kwargs)
        if not success:
            return [], False

        # Re-shape hub results into the shared discover-format dicts.
        converted = [
            {
                "id": result.id,
                "name": result.name,
                "description": result.description,
                "size": result.size_formatted,
                "type": "HuggingFace",
                "downloads": result.downloads,
                "likes": result.likes,
                "is_cached": False,
                "tags": result.tags,
                "author": result.author,
                "updated_at": result.updated_at or "",
                "pipeline_tag": result.pipeline_tag,
                "library_name": result.library_name,
            }
            for result in hf_models
        ]
        return converted, True

    except Exception as e:
        # Best-effort lookup: log at debug level and report failure upward.
        logger.debug(f"Error getting HuggingFace models: {e}")
        return [], False
| 416 | + |
289 | 417 | @models.command() |
290 | 418 | @click.argument('model_id') |
291 | 419 | def info(model_id: str): |
|
0 commit comments