Skip to content

Commit 3fdf079

Browse files
committed
优化:API 请求超时和并发限制
1 parent aac690a commit 3fdf079

File tree

1 file changed

+50
-34
lines changed

1 file changed

+50
-34
lines changed

src/core/model_manager.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import aiofiles
88
from urllib.parse import urlparse
99
import time
10+
import asyncio
11+
from aiohttp import ClientTimeout
1012

1113
from src.utils.hash_utils import HashUtils
1214

@@ -20,6 +22,9 @@ def __init__(self, config_file="config.json"):
2022
self.images_path = Path("static/images") # 添加图片保存路径
2123
self.images_path.mkdir(parents=True, exist_ok=True) # 确保目录存在
2224
self.hash_utils = HashUtils()
25+
# 添加并发限制和超时设置
26+
self.semaphore = asyncio.Semaphore(5) # 限制并发请求数
27+
self.timeout = ClientTimeout(total=10) # 10秒超时
2328

2429
def load_config(self) -> dict:
2530
"""加载配置文件"""
@@ -86,32 +91,34 @@ async def scan_models(self):
8691

8792
async def fetch_model_info(self, model_hash, file_path, mtime: float):
8893
"""从Civitai API获取模型信息并下载预览图"""
89-
try:
90-
response = requests.get(f"{self.api_base_url}/model-versions/by-hash/{model_hash}")
91-
if response.status_code == 200:
92-
model_info = response.json()
93-
94-
# 下载预览图
95-
preview_url = model_info.get("images", [{}])[0].get("url")
96-
if preview_url:
97-
local_preview = await self.download_image(preview_url)
98-
if local_preview:
99-
model_info = {
100-
**model_info,
101-
"local_preview": local_preview,
102-
"mtime": mtime, # 记录文件修改时间
103-
"scan_time": time.time() # 记录扫描时间
104-
}
105-
106-
self.models_info[str(file_path)] = {
107-
"hash": model_hash,
108-
"info": model_info
109-
}
110-
print(f"成功获取模型信息: {file_path.name}")
111-
else:
112-
print(f"无法获取模型信息: {file_path.name}, 状态码: {response.status_code}")
113-
except Exception as e:
114-
print(f"获取模型信息时出错: {file_path.name}, 错误: {str(e)}")
94+
async with self.semaphore: # 使用信号量限制并发
95+
try:
96+
async with aiohttp.ClientSession(timeout=self.timeout) as session:
97+
async with session.get(f"{self.api_base_url}/model-versions/by-hash/{model_hash}") as response:
98+
if response.status == 200:
99+
model_info = await response.json()
100+
101+
# 下载预览图
102+
preview_url = model_info.get("images", [{}])[0].get("url")
103+
if preview_url:
104+
local_preview = await self.download_image(preview_url)
105+
if local_preview:
106+
model_info = {
107+
**model_info,
108+
"local_preview": local_preview,
109+
"mtime": mtime, # 记录文件修改时间
110+
"scan_time": time.time() # 记录扫描时间
111+
}
112+
113+
self.models_info[str(file_path)] = {
114+
"hash": model_hash,
115+
"info": model_info
116+
}
117+
print(f"成功获取模型信息: {file_path.name}")
118+
else:
119+
print(f"无法获取模型信息: {file_path.name}, 状态码: {response.status}")
120+
except Exception as e:
121+
print(f"获取模型信息时出错: {file_path.name}, 错误: {str(e)}")
115122

116123
def save_models_info(self, output_file="models_info.json"):
117124
"""保存模型信息到JSON文件"""
@@ -161,18 +168,27 @@ async def download_image(self, url: str) -> str:
161168
if local_path.exists():
162169
return f"/static/images/{filename}"
163170

171+
max_retries = 3
172+
retry_delay = 1
173+
164174
try:
165-
async with aiohttp.ClientSession() as session:
166-
async with session.get(url) as response:
167-
if response.status == 200:
168-
async with aiofiles.open(local_path, 'wb') as f:
169-
await f.write(await response.read())
170-
return f"/static/images/{filename}"
175+
for attempt in range(max_retries):
176+
try:
177+
async with aiohttp.ClientSession(timeout=self.timeout) as session:
178+
async with session.get(url) as response:
179+
if response.status == 200:
180+
async with aiofiles.open(local_path, 'wb') as f:
181+
await f.write(await response.read())
182+
return f"/static/images/{filename}"
183+
break # 如果成功就跳出重试循环
184+
except Exception as e:
185+
if attempt < max_retries - 1: # 如果不是最后一次尝试
186+
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
187+
continue
188+
raise # 最后一次尝试失败,抛出异常
171189
except Exception as e:
172190
print(f"下载图片失败: {url}, 错误: {str(e)}")
173191
return None
174-
175-
return None
176192

177193
def get_model_display_info(self, model_path: str) -> dict:
178194
"""获取用于显示的模型信息"""

0 commit comments

Comments
 (0)