Skip to content

Commit 0086d69

Browse files
committed
Fix formatting
Signed-off-by: Victor Chang <[email protected]>
1 parent 574191a commit 0086d69

File tree

6 files changed

+61
-69
lines changed

6 files changed

+61
-69
lines changed

tools/pipeline-generator/pipeline_generator/cli/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def list(ctx: click.Context, format: str, bundles_only: bool, tested_only: bool,
9999

100100
# Fetch models from all endpoints
101101
console.print("[blue]Fetching models from HuggingFace...[/blue]")
102-
102+
103103
if all:
104104
# Show all models, but fetch detailed info for MONAI Bundles to get accurate extension data
105105
console.print("[yellow]Note: Fetching detailed info for MONAI Bundles to show accurate extension data[/yellow]")
@@ -263,7 +263,7 @@ def _display_table(models: List[ModelInfo], verified_models: Set[str]) -> None:
263263
bundle_status = f"[dim]✗ No ({primary_ext})[/dim]"
264264
else:
265265
bundle_status = "[dim]✗ No[/dim]"
266-
266+
267267
status = "[bold green]✓ Verified[/bold green]" if model.model_id in verified_models else ""
268268
table.add_row(
269269
model.model_id,

tools/pipeline-generator/pipeline_generator/config/settings.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ class Settings(BaseModel):
5555
"""Application settings."""
5656

5757
supported_models: List[str] = Field(
58-
default_factory=lambda: [".ts", ".pt", ".pth", ".safetensors"],
59-
description="Supported model file extensions"
58+
default_factory=lambda: [".ts", ".pt", ".pth", ".safetensors"], description="Supported model file extensions"
6059
)
6160
endpoints: List[Endpoint] = Field(default_factory=list)
6261
additional_models: List[Endpoint] = Field(default_factory=list)

tools/pipeline-generator/pipeline_generator/core/hub_client.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@ class HuggingFaceClient:
2828

2929
def __init__(self, settings: Optional[Settings] = None) -> None:
3030
"""Initialize the HuggingFace Hub client.
31-
31+
3232
Args:
3333
settings: Pipeline generator settings containing supported extensions
3434
"""
3535
self.api = HfApi()
3636
self.settings = settings
37-
self.supported_extensions = (
38-
settings.supported_models if settings else [".ts", ".pt", ".pth", ".safetensors"]
39-
)
37+
self.supported_extensions = settings.supported_models if settings else [".ts", ".pt", ".pth", ".safetensors"]
4038

4139
def list_models_from_organization(self, organization: str) -> List[ModelInfo]:
4240
"""List all models from a HuggingFace organization.
@@ -79,7 +77,9 @@ def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
7977
logger.error(f"Error fetching model {model_id}: {e}")
8078
return None
8179

82-
def list_models_from_endpoints(self, endpoints: List[Endpoint], fetch_details_for_bundles: bool = False) -> List[ModelInfo]:
80+
def list_models_from_endpoints(
81+
self, endpoints: List[Endpoint], fetch_details_for_bundles: bool = False
82+
) -> List[ModelInfo]:
8383
"""List models from all configured endpoints.
8484
8585
Args:
@@ -96,7 +96,7 @@ def list_models_from_endpoints(self, endpoints: List[Endpoint], fetch_details_fo
9696
# List all models from organization
9797
logger.info(f"Fetching models from organization: {endpoint.organization}")
9898
models = self.list_models_from_organization(endpoint.organization)
99-
99+
100100
if fetch_details_for_bundles:
101101
# For MONAI organization, fetch detailed info for all models to get accurate extension data
102102
# This is needed because bulk API doesn't provide file information (siblings=None)
@@ -118,7 +118,7 @@ def list_models_from_endpoints(self, endpoints: List[Endpoint], fetch_details_fo
118118
else:
119119
enhanced_models.append(model)
120120
models = enhanced_models
121-
121+
122122
all_models.extend(models)
123123

124124
elif endpoint.model_id:
@@ -140,7 +140,7 @@ def _detect_model_extensions(self, model_data: Any) -> List[str]:
140140
List of detected model file extensions
141141
"""
142142
extensions = []
143-
143+
144144
try:
145145
if hasattr(model_data, "siblings") and model_data.siblings is not None:
146146
file_names = [f.rfilename for f in model_data.siblings]
@@ -151,13 +151,13 @@ def _detect_model_extensions(self, model_data: Any) -> List[str]:
151151
extensions.append(ext)
152152
except Exception as e:
153153
logger.debug(f"Could not detect extensions for {getattr(model_data, 'modelId', 'unknown')}: {e}")
154-
154+
155155
return extensions
156156

157157
def list_torchscript_models(self, endpoints: List[Endpoint]) -> List[ModelInfo]:
158158
"""List models that have TorchScript (.ts) files.
159159
160-
This method fetches detailed information for each model individually to
160+
This method fetches detailed information for each model individually to
161161
check for TorchScript files, which is slower than bulk listing but accurate.
162162
163163
Args:
@@ -167,7 +167,7 @@ def list_torchscript_models(self, endpoints: List[Endpoint]) -> List[ModelInfo]:
167167
List of ModelInfo objects that contain .ts files
168168
"""
169169
torchscript_models = []
170-
170+
171171
for endpoint in endpoints:
172172
if endpoint.organization:
173173
# List all models from organization first (bulk)

tools/pipeline-generator/pipeline_generator/core/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def primary_extension(self) -> Optional[str]:
8080
"""
8181
if not self.model_extensions:
8282
return None
83-
83+
8484
# Prioritize .ts extension
8585
if ".ts" in self.model_extensions:
8686
return ".ts"
87-
87+
8888
return self.model_extensions[0]

tools/pipeline-generator/tests/test_cli.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ def test_list_command_tested_only(self, mock_load_config, mock_client_class):
250250

251251
# Mock the list response
252252
test_models = [
253-
ModelInfo(model_id="MONAI/tested_model", name="Tested Model", is_monai_bundle=True, model_extensions=[".ts"]),
253+
ModelInfo(
254+
model_id="MONAI/tested_model", name="Tested Model", is_monai_bundle=True, model_extensions=[".ts"]
255+
),
254256
ModelInfo(
255257
model_id="MONAI/untested_model",
256258
name="Untested Model",
@@ -309,7 +311,7 @@ def test_list_command_all_flag(self, mock_load_config, mock_client_class):
309311
is_monai_bundle=True,
310312
),
311313
ModelInfo(
312-
model_id="MONAI/pt_model",
314+
model_id="MONAI/pt_model",
313315
name="PyTorch Model",
314316
model_extensions=[".pt"],
315317
is_monai_bundle=False,
@@ -331,7 +333,7 @@ def test_list_command_all_flag(self, mock_load_config, mock_client_class):
331333
assert "Verified: 0" in result.output
332334

333335
@patch("pipeline_generator.cli.main.HuggingFaceClient")
334-
@patch("pipeline_generator.cli.main.load_config")
336+
@patch("pipeline_generator.cli.main.load_config")
335337
def test_list_command_default_torchscript_only(self, mock_load_config, mock_client_class):
336338
"""Test list command defaults to torchscript models only."""
337339
# Mock setup
@@ -347,7 +349,7 @@ def test_list_command_default_torchscript_only(self, mock_load_config, mock_clie
347349
test_models = [
348350
ModelInfo(
349351
model_id="MONAI/ts_model",
350-
name="TorchScript Model",
352+
name="TorchScript Model",
351353
model_extensions=[".ts"],
352354
is_monai_bundle=True,
353355
),
@@ -358,10 +360,8 @@ def test_list_command_default_torchscript_only(self, mock_load_config, mock_clie
358360
result = self.runner.invoke(cli, ["list"])
359361

360362
assert result.exit_code == 0
361-
# Should call list_torchscript_models (default behavior)
362-
mock_client.list_torchscript_models.assert_called_once_with(
363-
mock_settings.get_all_endpoints.return_value
364-
)
363+
# Should call list_torchscript_models (default behavior)
364+
mock_client.list_torchscript_models.assert_called_once_with(mock_settings.get_all_endpoints.return_value)
365365
mock_client.list_models_from_endpoints.assert_not_called()
366366
assert "MONAI/ts_model" in result.output
367367
assert "Verified: 0" in result.output
@@ -389,7 +389,7 @@ def test_list_command_monai_bundle_column_logic(self, mock_load_config, mock_cli
389389
),
390390
ModelInfo(
391391
model_id="MONAI/pt_model",
392-
name="PyTorch Model",
392+
name="PyTorch Model",
393393
model_extensions=[".pt"],
394394
is_monai_bundle=False, # Should be False for .pt files
395395
),
@@ -408,17 +408,17 @@ def test_list_command_monai_bundle_column_logic(self, mock_load_config, mock_cli
408408
assert result.exit_code == 0
409409
# Check MONAI Bundle column contents with new display format
410410
output = result.output
411-
411+
412412
# Should show "✓" and "Yes" for .ts model (MONAI Bundle) - may be on separate lines due to table wrapping
413413
assert "MONAI/ts_model" in output
414414
assert "✓" in output # Checkmark emoji
415415
assert "Yes" in output # Text
416-
417-
# Should show "✗" and "No" for .pt model
416+
417+
# Should show "✗" and "No" for .pt model
418418
assert "MONAI/pt_model" in output
419419
assert "✗" in output # X emoji
420420
assert "No" in output # The "No" text should appear
421-
421+
422422
# Should show "✗ No" for model with no extensions
423423
assert "MONAI/no_ext_model" in output
424424
# The Verified count may have color codes, so check for the text parts
@@ -432,10 +432,12 @@ def test_list_command_with_verified_models(self, mock_load_config, mock_client_c
432432
mock_settings = Mock()
433433
mock_settings.get_all_endpoints.return_value = [Mock(organization="MONAI")]
434434
mock_settings.endpoints = [
435-
Mock(models=[
436-
Mock(model_id="MONAI/verified_model1"),
437-
Mock(model_id="MONAI/verified_model2"),
438-
])
435+
Mock(
436+
models=[
437+
Mock(model_id="MONAI/verified_model1"),
438+
Mock(model_id="MONAI/verified_model2"),
439+
]
440+
)
439441
]
440442
mock_load_config.return_value = mock_settings
441443

@@ -466,13 +468,13 @@ def test_list_command_with_verified_models(self, mock_load_config, mock_client_c
466468
assert "MONAI/verified_model1" in result.output
467469
assert "MONAI/unverified_model" in result.output
468470
assert "Verified: 1" in result.output
469-
471+
470472
# Check that verified model shows verification checkmark
471-
output_lines = result.output.split('\n')
473+
output_lines = result.output.split("\n")
472474
verified_line = [line for line in output_lines if "MONAI/verified_model1" in line]
473475
assert any("✓ Verified" in line for line in verified_line)
474-
475-
unverified_line = [line for line in output_lines if "MONAI/unverified_model" in line]
476+
477+
unverified_line = [line for line in output_lines if "MONAI/unverified_model" in line]
476478
assert not any("✓ Verified" in line for line in unverified_line)
477479

478480
@patch("pipeline_generator.cli.main.HuggingFaceClient")
@@ -482,9 +484,7 @@ def test_list_command_json_output(self, mock_load_config, mock_client_class):
482484
# Mock setup
483485
mock_settings = Mock()
484486
mock_settings.get_all_endpoints.return_value = [Mock(organization="MONAI")]
485-
mock_settings.endpoints = [
486-
Mock(models=[Mock(model_id="MONAI/test_model")])
487-
]
487+
mock_settings.endpoints = [Mock(models=[Mock(model_id="MONAI/test_model")])]
488488
mock_load_config.return_value = mock_settings
489489

490490
mock_client = Mock()
@@ -508,17 +508,18 @@ def test_list_command_json_output(self, mock_load_config, mock_client_class):
508508
result = self.runner.invoke(cli, ["list", "--format", "json"])
509509

510510
assert result.exit_code == 0
511-
511+
512512
# Parse JSON output to verify new fields
513513
import json
514-
json_start = result.output.find('[')
515-
json_end = result.output.rfind(']') + 1 # Find the last ] and include it
514+
515+
json_start = result.output.find("[")
516+
json_end = result.output.rfind("]") + 1 # Find the last ] and include it
516517
json_text = result.output[json_start:json_end]
517518
json_data = json.loads(json_text)
518-
519+
519520
assert len(json_data) == 1
520521
model_data = json_data[0]
521-
522+
522523
# Check all new fields are present
523524
assert model_data["model_id"] == "MONAI/test_model"
524525
assert model_data["is_monai_bundle"] is True

tools/pipeline-generator/tests/test_hub_client.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ def test_list_models_from_organization_success(self, mock_list_models):
4545
created_at=datetime(2023, 1, 1),
4646
lastModified=datetime(2023, 12, 1),
4747
tags=["medical", "segmentation"],
48-
siblings=[
49-
Mock(rfilename="configs/metadata.json"),
50-
Mock(rfilename="models/model.ts")
51-
],
48+
siblings=[Mock(rfilename="configs/metadata.json"), Mock(rfilename="models/model.ts")],
5249
)
5350

5451
mock_model2 = SimpleModelData(
@@ -104,10 +101,7 @@ def test_get_model_info_success(self, mock_model_info):
104101
created_at=datetime(2023, 1, 1),
105102
lastModified=datetime(2023, 12, 1),
106103
tags=["medical", "segmentation"],
107-
siblings=[
108-
Mock(rfilename="configs/metadata.json"),
109-
Mock(rfilename="models/model.ts")
110-
],
104+
siblings=[Mock(rfilename="configs/metadata.json"), Mock(rfilename="models/model.ts")],
111105
cardData={"description": "Spleen segmentation model"},
112106
)
113107

@@ -201,10 +195,7 @@ def test_extract_model_info_bundle_detection(self):
201195
assert model.is_monai_bundle is True
202196

203197
# Test without TorchScript (.ts) file - only .pt file
204-
mock_model.siblings = [
205-
Mock(rfilename="configs/metadata.json"),
206-
Mock(rfilename="models/model.pt")
207-
]
198+
mock_model.siblings = [Mock(rfilename="configs/metadata.json"), Mock(rfilename="models/model.pt")]
208199
model = self.client._extract_model_info(mock_model)
209200
assert model.is_monai_bundle is False
210201

@@ -384,11 +375,11 @@ def test_detect_model_extensions_with_torchscript(self):
384375
Mock(rfilename="model.ts"),
385376
Mock(rfilename="config.yaml"),
386377
Mock(rfilename="README.md"),
387-
]
378+
],
388379
)
389380

390381
extensions = self.client._detect_model_extensions(mock_model)
391-
382+
392383
assert ".ts" in extensions
393384
assert len(extensions) == 1
394385

@@ -402,11 +393,11 @@ def test_detect_model_extensions_multiple_formats(self):
402393
Mock(rfilename="model.pt"),
403394
Mock(rfilename="model.safetensors"),
404395
Mock(rfilename="config.yaml"),
405-
]
396+
],
406397
)
407398

408399
extensions = self.client._detect_model_extensions(mock_model)
409-
400+
410401
assert ".ts" in extensions
411402
assert ".pt" in extensions
412403
assert ".safetensors" in extensions
@@ -420,11 +411,11 @@ def test_detect_model_extensions_no_model_files(self):
420411
siblings=[
421412
Mock(rfilename="README.md"),
422413
Mock(rfilename="config.yaml"),
423-
]
414+
],
424415
)
425416

426417
extensions = self.client._detect_model_extensions(mock_model)
427-
418+
428419
assert len(extensions) == 0
429420

430421
def test_detect_model_extensions_no_siblings(self):
@@ -433,7 +424,7 @@ def test_detect_model_extensions_no_siblings(self):
433424
mock_model = SimpleModelData(modelId="MONAI/test_model")
434425

435426
extensions = self.client._detect_model_extensions(mock_model)
436-
427+
437428
assert len(extensions) == 0
438429

439430
@patch("pipeline_generator.core.hub_client.list_models")
@@ -452,7 +443,7 @@ def test_list_torchscript_models_filters_correctly(self, mock_list_models):
452443
)
453444

454445
mock_model_without_ts = SimpleModelData(
455-
modelId="MONAI/model_without_ts",
446+
modelId="MONAI/model_without_ts",
456447
author="MONAI",
457448
downloads=50,
458449
likes=5,
@@ -472,13 +463,14 @@ def mock_get_model_info(model_id):
472463
return self.client._extract_model_info(mock_model_without_ts)
473464
return None
474465

475-
with patch.object(self.client, 'get_model_info', side_effect=mock_get_model_info):
466+
with patch.object(self.client, "get_model_info", side_effect=mock_get_model_info):
476467
from pipeline_generator.config.settings import Endpoint
468+
477469
endpoints = [Endpoint(organization="MONAI")]
478470

479471
# Test the torchscript filtering
480472
torchscript_models = self.client.list_torchscript_models(endpoints)
481-
473+
482474
# Should only return the model with .ts file
483475
assert len(torchscript_models) == 1
484476
assert torchscript_models[0].model_id == "MONAI/model_with_ts"
@@ -502,7 +494,7 @@ def test_extract_model_info_includes_extensions(self):
502494
)
503495

504496
result = self.client._extract_model_info(mock_model)
505-
497+
506498
assert ".ts" in result.model_extensions
507499
assert ".pt" in result.model_extensions
508500
assert result.has_torchscript is True

0 commit comments

Comments
 (0)