Skip to content

Commit ca4b8f3

Browse files
authored
Cleanup empty dir if frontend zip download failed (Comfy-Org#4574)
1 parent 70b8405 commit ca4b8f3

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

app/frontend_management.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dataclasses import dataclass
99
from functools import cached_property
1010
from pathlib import Path
11-
from typing import TypedDict
11+
from typing import TypedDict, Optional
1212

1313
import requests
1414
from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ def parse_version_string(cls, value: str) -> tuple[str, str, str]:
132132
return match_result.group(1), match_result.group(2), match_result.group(3)
133133

134134
@classmethod
135-
def init_frontend_unsafe(cls, version_string: str) -> str:
135+
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
136136
"""
137137
Initializes the frontend for the specified version.
138138
139139
Args:
140140
version_string (str): The version string.
141+
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
141142
142143
Returns:
143144
str: The path to the initialized frontend.
@@ -150,23 +151,29 @@ def init_frontend_unsafe(cls, version_string: str) -> str:
150151
return cls.DEFAULT_FRONTEND_PATH
151152

152153
repo_owner, repo_name, version = cls.parse_version_string(version_string)
153-
provider = FrontEndProvider(repo_owner, repo_name)
154+
provider = provider or FrontEndProvider(repo_owner, repo_name)
154155
release = provider.get_release(version)
155156

156157
semantic_version = release["tag_name"].lstrip("v")
157158
web_root = str(
158159
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
159160
)
160161
if not os.path.exists(web_root):
161-
os.makedirs(web_root, exist_ok=True)
162-
logging.info(
163-
"Downloading frontend(%s) version(%s) to (%s)",
164-
provider.folder_name,
165-
semantic_version,
166-
web_root,
167-
)
168-
logging.debug(release)
169-
download_release_asset_zip(release, destination_path=web_root)
162+
try:
163+
os.makedirs(web_root, exist_ok=True)
164+
logging.info(
165+
"Downloading frontend(%s) version(%s) to (%s)",
166+
provider.folder_name,
167+
semantic_version,
168+
web_root,
169+
)
170+
logging.debug(release)
171+
download_release_asset_zip(release, destination_path=web_root)
172+
finally:
173+
# Clean up the directory if it is empty, i.e. the download failed
174+
if not os.listdir(web_root):
175+
os.rmdir(web_root)
176+
170177
return web_root
171178

172179
@classmethod

tests-unit/app_test/frontend_manager_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import pytest
33
from requests.exceptions import HTTPError
4+
from unittest.mock import patch
45

56
from app.frontend_management import (
67
FrontendManager,
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
8384
with pytest.raises(HTTPError):
8485
FrontendManager.init_frontend_unsafe(version_string)
8586

87+
@pytest.fixture
88+
def mock_os_functions():
89+
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
90+
patch('app.frontend_management.os.listdir') as mock_listdir, \
91+
patch('app.frontend_management.os.rmdir') as mock_rmdir:
92+
mock_listdir.return_value = [] # Simulate empty directory
93+
yield mock_makedirs, mock_listdir, mock_rmdir
94+
95+
@pytest.fixture
96+
def mock_download():
97+
with patch('app.frontend_management.download_release_asset_zip') as mock:
98+
mock.side_effect = Exception("Download failed") # Simulate download failure
99+
yield mock
100+
101+
def test_finally_block(mock_os_functions, mock_download, mock_provider):
102+
# Arrange
103+
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
104+
version_string = 'test-owner/[email protected]'
105+
106+
# Act & Assert
107+
with pytest.raises(Exception):
108+
FrontendManager.init_frontend_unsafe(version_string, mock_provider)
109+
110+
# Assert
111+
mock_makedirs.assert_called_once()
112+
mock_download.assert_called_once()
113+
mock_listdir.assert_called_once()
114+
mock_rmdir.assert_called_once()
115+
86116

87117
def test_parse_version_string():
88118
version_string = "owner/[email protected]"

0 commit comments

Comments
 (0)