Skip to content

Commit 574191a

Browse files
committed
Enhance model handling and CLI functionality in pipeline generator
- Added support for listing all models regardless of file extension with the `--all` flag in the CLI.
- Updated the model listing logic to differentiate between verified and non-verified models.
- Enhanced the model information structure to include detected file extensions and primary extension for better clarity.
- Improved documentation to reflect the new model handling capabilities and supported file formats.

Signed-off-by: Victor Chang <[email protected]>
1 parent 59de165 commit 574191a

File tree

8 files changed

+606
-65
lines changed

8 files changed

+606
-65
lines changed

tools/pipeline-generator/docs/design.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ As of MONAI Core, bundles can also be exported in a **Hugging Face-compatible fo
3030

3131
- The tool does not convert input formats given that each model may expect a different type of input
3232
- The tool does not convert output formats given that each model may output a different type of result
33+
- The tool supports only TorchScript (`.ts`) models
3334

3435
## **Scope**
3536

tools/pipeline-generator/pipeline_generator/cli/main.py

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -64,61 +64,74 @@ def cli(ctx: click.Context, config: Optional[str]) -> None:
6464
)
6565
@click.option("--bundles-only", "-b", is_flag=True, help="Show only MONAI Bundles")
6666
@click.option("--tested-only", "-t", is_flag=True, help="Show only tested models")
67+
@click.option("--all", is_flag=True, help="Show all models regardless of file extension")
6768
@click.pass_context
68-
def list(ctx: click.Context, format: str, bundles_only: bool, tested_only: bool) -> None:
69+
def list(ctx: click.Context, format: str, bundles_only: bool, tested_only: bool, all: bool) -> None:
6970
"""List available models from configured endpoints.
7071
72+
By default, only shows models with TorchScript (.ts) files.
73+
Use --all to show models with any supported extension.
74+
7175
Args:
7276
ctx: Click context containing configuration
7377
format: Output format (table, simple, or json)
7478
bundles_only: If True, show only MONAI Bundles
7579
tested_only: If True, show only tested models
80+
all: If True, show all models regardless of file extension
7681
7782
Example:
7883
pg list --format table --bundles-only
84+
pg list --all # Show all models
7985
"""
8086

8187
# Load configuration
8288
config_path = ctx.obj.get("config_path")
8389
settings = load_config(config_path)
8490

85-
# Get set of tested model IDs from configuration
86-
tested_models = set()
91+
# Get set of verified model IDs from configuration
92+
verified_models = set()
8793
for endpoint in settings.endpoints:
8894
for model in endpoint.models:
89-
tested_models.add(model.model_id)
95+
verified_models.add(model.model_id)
9096

91-
# Create HuggingFace client
92-
client = HuggingFaceClient()
97+
# Create HuggingFace client with settings
98+
client = HuggingFaceClient(settings=settings)
9399

94100
# Fetch models from all endpoints
95101
console.print("[blue]Fetching models from HuggingFace...[/blue]")
96-
models = client.list_models_from_endpoints(settings.get_all_endpoints())
102+
103+
if all:
104+
# Show all models, but fetch detailed info for MONAI Bundles to get accurate extension data
105+
console.print("[yellow]Note: Fetching detailed info for MONAI Bundles to show accurate extension data[/yellow]")
106+
models = client.list_models_from_endpoints(settings.get_all_endpoints(), fetch_details_for_bundles=True)
107+
else:
108+
# Show only models with TorchScript (.ts) files by default
109+
models = client.list_torchscript_models(settings.get_all_endpoints())
97110

98111
# Filter for bundles if requested
99112
if bundles_only:
100113
models = [m for m in models if m.is_monai_bundle]
101114

102-
# Filter for tested models if requested
115+
# Filter for verified models if requested
103116
if tested_only:
104-
models = [m for m in models if m.model_id in tested_models]
117+
models = [m for m in models if m.model_id in verified_models]
105118

106119
# Sort models by name
107120
models.sort(key=lambda m: m.display_name)
108121

109122
# Display results based on format
110123
if format == "table":
111-
_display_table(models, tested_models)
124+
_display_table(models, verified_models)
112125
elif format == "simple":
113-
_display_simple(models, tested_models)
126+
_display_simple(models, verified_models)
114127
elif format == "json":
115-
_display_json(models, tested_models)
128+
_display_json(models, verified_models)
116129

117130
# Summary
118131
bundle_count = sum(1 for m in models if m.is_monai_bundle)
119-
tested_count = sum(1 for m in models if m.model_id in tested_models)
132+
verified_count = sum(1 for m in models if m.model_id in verified_models)
120133
console.print(
121-
f"\n[green]Total models: {len(models)} (MONAI Bundles: {bundle_count}, Verified: {tested_count})[/green]"
134+
f"\n[green]Total models: {len(models)} (MONAI Bundles: {bundle_count}, Verified: {verified_count})[/green]"
122135
)
123136

124137

@@ -225,28 +238,37 @@ def gen(
225238
raise click.Abort() from e
226239

227240

228-
def _display_table(models: List[ModelInfo], tested_models: Set[str]) -> None:
241+
def _display_table(models: List[ModelInfo], verified_models: Set[str]) -> None:
229242
"""Display models in a rich table format.
230243
231244
Args:
232245
models: List of ModelInfo objects to display
233-
tested_models: Set of tested model IDs
246+
verified_models: Set of verified model IDs
234247
"""
235248
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
236249
table.add_column("Model ID", style="cyan", width=40)
237250
table.add_column("Name", style="white")
238-
table.add_column("Type", style="green")
251+
table.add_column("MONAI Bundle", style="green")
239252
table.add_column("Status", style="blue", width=10)
240253
table.add_column("Downloads", justify="right", style="yellow")
241254
table.add_column("Likes", justify="right", style="red")
242255

243256
for model in models:
244-
model_type = "[green]MONAI Bundle[/green]" if model.is_monai_bundle else "Model"
245-
status = "[bold green]✓ Verified[/bold green]" if model.model_id in tested_models else ""
257+
# MONAI Bundle column logic: "Yes" if is_monai_bundle (has .ts), "No (extension)" otherwise
258+
if model.is_monai_bundle:
259+
bundle_status = "[green]✓ Yes[/green]"
260+
else:
261+
primary_ext = model.primary_extension
262+
if primary_ext:
263+
bundle_status = f"[dim]✗ No ({primary_ext})[/dim]"
264+
else:
265+
bundle_status = "[dim]✗ No[/dim]"
266+
267+
status = "[bold green]✓ Verified[/bold green]" if model.model_id in verified_models else ""
246268
table.add_row(
247269
model.model_id,
248270
model.display_name,
249-
model_type,
271+
bundle_status,
250272
status,
251273
str(model.downloads or "N/A"),
252274
str(model.likes or "N/A"),
@@ -255,31 +277,31 @@ def _display_table(models: List[ModelInfo], tested_models: Set[str]) -> None:
255277
console.print(table)
256278

257279

258-
def _display_simple(models: List[ModelInfo], tested_models: Set[str]) -> None:
280+
def _display_simple(models: List[ModelInfo], verified_models: Set[str]) -> None:
259281
"""Display models in a simple list format.
260282
261283
Shows each model with emoji indicators:
262284
- 📦 for MONAI Bundle, 📄 for regular model
263-
- ✓ for tested models
285+
- ✓ for verified models
264286
265287
Args:
266288
models: List of ModelInfo objects to display
267-
tested_models: Set of tested model IDs
289+
verified_models: Set of verified model IDs
268290
"""
269291
for model in models:
270292
bundle_marker = "📦" if model.is_monai_bundle else "📄"
271-
tested_marker = " ✓" if model.model_id in tested_models else ""
272-
console.print(f"{bundle_marker} {model.model_id} - {model.display_name}{tested_marker}")
293+
verified_marker = " ✓" if model.model_id in verified_models else ""
294+
console.print(f"{bundle_marker} {model.model_id} - {model.display_name}{verified_marker}")
273295

274296

275-
def _display_json(models: List[ModelInfo], tested_models: Set[str]) -> None:
297+
def _display_json(models: List[ModelInfo], verified_models: Set[str]) -> None:
276298
"""Display models in JSON format.
277299
278300
Outputs a JSON array of model information suitable for programmatic consumption.
279301
280302
Args:
281303
models: List of ModelInfo objects to display
282-
tested_models: Set of tested model IDs
304+
verified_models: Set of verified model IDs
283305
"""
284306
import json
285307

@@ -288,7 +310,10 @@ def _display_json(models: List[ModelInfo], tested_models: Set[str]) -> None:
288310
"model_id": m.model_id,
289311
"name": m.display_name,
290312
"is_monai_bundle": m.is_monai_bundle,
291-
"is_tested": m.model_id in tested_models,
313+
"has_torchscript": m.has_torchscript,
314+
"model_extensions": m.model_extensions,
315+
"primary_extension": m.primary_extension,
316+
"is_verified": m.model_id in verified_models,
292317
"downloads": m.downloads,
293318
"likes": m.likes,
294319
"tags": m.tags,

tools/pipeline-generator/pipeline_generator/config/config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
# Pipeline Generator Configuration
1313

14+
# Supported model file extensions
15+
supported_models:
16+
- ".ts" # TorchScript (MONAI Bundle)
17+
- ".pt" # PyTorch state dict
18+
- ".pth" # PyTorch model
19+
- ".safetensors" # SafeTensor format
20+
1421
# HuggingFace endpoints to scan for MONAI models
1522
endpoints:
1623
- organization: "MONAI"

tools/pipeline-generator/pipeline_generator/config/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class Endpoint(BaseModel):
5454
class Settings(BaseModel):
5555
"""Application settings."""
5656

57+
supported_models: List[str] = Field(
58+
default_factory=lambda: [".ts", ".pt", ".pth", ".safetensors"],
59+
description="Supported model file extensions"
60+
)
5761
endpoints: List[Endpoint] = Field(default_factory=list)
5862
additional_models: List[Endpoint] = Field(default_factory=list)
5963

tools/pipeline-generator/pipeline_generator/core/hub_client.py

Lines changed: 102 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from huggingface_hub import HfApi, list_models, model_info
1818
from huggingface_hub.utils import HfHubHTTPError
1919

20-
from ..config import Endpoint
20+
from ..config import Endpoint, Settings
2121
from .models import ModelInfo
2222

2323
logger = logging.getLogger(__name__)
@@ -26,9 +26,17 @@
2626
class HuggingFaceClient:
2727
"""Client for interacting with HuggingFace Hub."""
2828

29-
def __init__(self) -> None:
30-
"""Initialize the HuggingFace Hub client."""
29+
def __init__(self, settings: Optional[Settings] = None) -> None:
30+
"""Initialize the HuggingFace Hub client.
31+
32+
Args:
33+
settings: Pipeline generator settings containing supported extensions
34+
"""
3135
self.api = HfApi()
36+
self.settings = settings
37+
self.supported_extensions = (
38+
settings.supported_models if settings else [".ts", ".pt", ".pth", ".safetensors"]
39+
)
3240

3341
def list_models_from_organization(self, organization: str) -> List[ModelInfo]:
3442
"""List all models from a HuggingFace organization.
@@ -71,11 +79,12 @@ def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
7179
logger.error(f"Error fetching model {model_id}: {e}")
7280
return None
7381

74-
def list_models_from_endpoints(self, endpoints: List[Endpoint]) -> List[ModelInfo]:
82+
def list_models_from_endpoints(self, endpoints: List[Endpoint], fetch_details_for_bundles: bool = False) -> List[ModelInfo]:
7583
"""List models from all configured endpoints.
7684
7785
Args:
7886
endpoints: List of endpoint configurations
87+
fetch_details_for_bundles: If True, fetch detailed info for potential MONAI Bundles to get accurate extension info
7988
8089
Returns:
8190
List of ModelInfo objects from all endpoints
@@ -87,6 +96,29 @@ def list_models_from_endpoints(self, endpoints: List[Endpoint]) -> List[ModelInf
8796
# List all models from organization
8897
logger.info(f"Fetching models from organization: {endpoint.organization}")
8998
models = self.list_models_from_organization(endpoint.organization)
99+
100+
if fetch_details_for_bundles:
101+
# For MONAI organization, fetch detailed info for all models to get accurate extension data
102+
# This is needed because bulk API doesn't provide file information (siblings=None)
103+
if endpoint.organization == "MONAI":
104+
enhanced_models = []
105+
for model in models:
106+
# Fetch detailed model info to get file extensions and accurate MONAI Bundle detection
107+
detailed_model = self.get_model_info(model.model_id)
108+
enhanced_models.append(detailed_model if detailed_model else model)
109+
models = enhanced_models
110+
else:
111+
# For non-MONAI organizations, only fetch details for models that might be bundles
112+
enhanced_models = []
113+
for model in models:
114+
if any("monai" in tag.lower() for tag in model.tags):
115+
# Fetch detailed model info to get file extensions
116+
detailed_model = self.get_model_info(model.model_id)
117+
enhanced_models.append(detailed_model if detailed_model else model)
118+
else:
119+
enhanced_models.append(model)
120+
models = enhanced_models
121+
90122
all_models.extend(models)
91123

92124
elif endpoint.model_id:
@@ -98,6 +130,66 @@ def list_models_from_endpoints(self, endpoints: List[Endpoint]) -> List[ModelInf
98130

99131
return all_models
100132

133+
def _detect_model_extensions(self, model_data: Any) -> List[str]:
134+
"""Detect model file extensions in a HuggingFace repository.
135+
136+
Args:
137+
model_data: Model data from HuggingFace API
138+
139+
Returns:
140+
List of detected model file extensions
141+
"""
142+
extensions = []
143+
144+
try:
145+
if hasattr(model_data, "siblings") and model_data.siblings is not None:
146+
file_names = [f.rfilename for f in model_data.siblings]
147+
for filename in file_names:
148+
for ext in self.supported_extensions:
149+
if filename.endswith(ext):
150+
if ext not in extensions:
151+
extensions.append(ext)
152+
except Exception as e:
153+
logger.debug(f"Could not detect extensions for {getattr(model_data, 'modelId', 'unknown')}: {e}")
154+
155+
return extensions
156+
157+
def list_torchscript_models(self, endpoints: List[Endpoint]) -> List[ModelInfo]:
158+
"""List models that have TorchScript (.ts) files.
159+
160+
This method fetches detailed information for each model individually to
161+
check for TorchScript files, which is slower than bulk listing but accurate.
162+
163+
Args:
164+
endpoints: List of endpoint configurations
165+
166+
Returns:
167+
List of ModelInfo objects that contain .ts files
168+
"""
169+
torchscript_models = []
170+
171+
for endpoint in endpoints:
172+
if endpoint.organization:
173+
# List all models from organization first (bulk)
174+
logger.info(f"Checking TorchScript models from organization: {endpoint.organization}")
175+
try:
176+
for model in list_models(author=endpoint.organization):
177+
# Fetch detailed model info to get file information
178+
detailed_model = self.get_model_info(model.modelId)
179+
if detailed_model and detailed_model.has_torchscript:
180+
torchscript_models.append(detailed_model)
181+
except Exception as e:
182+
logger.error(f"Error checking TorchScript models from {endpoint.organization}: {e}")
183+
184+
elif endpoint.model_id:
185+
# Get specific model
186+
logger.info(f"Checking TorchScript model: {endpoint.model_id}")
187+
model = self.get_model_info(endpoint.model_id)
188+
if model and model.has_torchscript:
189+
torchscript_models.append(model)
190+
191+
return torchscript_models
192+
101193
def _extract_model_info(self, model_data: Any) -> ModelInfo:
102194
"""Extract ModelInfo from HuggingFace model data.
103195
@@ -107,23 +199,14 @@ def _extract_model_info(self, model_data: Any) -> ModelInfo:
107199
Returns:
108200
ModelInfo object
109201
"""
110-
# Check if this is a MONAI Bundle
111-
is_monai_bundle = False
202+
# Detect model extensions
203+
model_extensions = self._detect_model_extensions(model_data)
204+
205+
# Check if this is a MONAI Bundle - defined as having TorchScript (.ts) files
206+
is_monai_bundle = ".ts" in model_extensions
112207
bundle_metadata = None
113208

114-
# Check tags for MONAI-related tags
115209
tags = getattr(model_data, "tags", [])
116-
if any("monai" in tag.lower() for tag in tags):
117-
is_monai_bundle = True
118-
119-
# Check if metadata.json exists in the model files
120-
try:
121-
if hasattr(model_data, "siblings"):
122-
file_names = [f.rfilename for f in model_data.siblings]
123-
if any("metadata.json" in f for f in file_names):
124-
is_monai_bundle = True
125-
except Exception:
126-
pass
127210

128211
# Extract description from cardData if available
129212
description = None
@@ -145,4 +228,5 @@ def _extract_model_info(self, model_data: Any) -> ModelInfo:
145228
tags=tags,
146229
is_monai_bundle=is_monai_bundle,
147230
bundle_metadata=bundle_metadata,
231+
model_extensions=model_extensions,
148232
)

0 commit comments

Comments
 (0)