Skip to content

Commit ba39c39

Browse files
Soulteranka-afk
andauthored
perf: enhance provider management with reload locking and logging (#3793)
- Introduced a reload lock to prevent concurrent reloads of providers. - Added logging to indicate when a provider is disabled and when providers are being synchronized with the configuration. - Refactored the reload method to improve clarity and maintainability. Co-authored-by: anka <[email protected]>
1 parent 6a50d31 commit ba39c39

File tree

1 file changed

+41
-33
lines changed

1 file changed

+41
-33
lines changed

astrbot/core/provider/manager.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import traceback
33

4-
from astrbot.core import logger, sp
4+
from astrbot.core import astrbot_config, logger, sp
55
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
66
from astrbot.core.db import BaseDatabase
77

@@ -24,6 +24,7 @@ def __init__(
2424
db_helper: BaseDatabase,
2525
persona_mgr: PersonaManager,
2626
):
27+
self.reload_lock = asyncio.Lock()
2728
self.persona_mgr = persona_mgr
2829
self.acm = acm
2930
config = acm.confs["default"]
@@ -226,6 +227,7 @@ async def initialize(self):
226227

227228
async def load_provider(self, provider_config: dict):
228229
if not provider_config["enable"]:
230+
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
229231
return
230232
if provider_config.get("provider_type", "") == "agent_runner":
231233
return
@@ -434,40 +436,46 @@ async def load_provider(self, provider_config: dict):
434436
)
435437

436438
async def reload(self, provider_config: dict):
437-
await self.terminate_provider(provider_config["id"])
438-
if provider_config["enable"]:
439-
await self.load_provider(provider_config)
440-
441-
# 和配置文件保持同步
442-
config_ids = [provider["id"] for provider in self.providers_config]
443-
logger.debug(f"providers in user's config: {config_ids}")
444-
for key in list(self.inst_map.keys()):
445-
if key not in config_ids:
446-
await self.terminate_provider(key)
447-
448-
if len(self.provider_insts) == 0:
449-
self.curr_provider_inst = None
450-
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
451-
self.curr_provider_inst = self.provider_insts[0]
452-
logger.info(
453-
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
454-
)
439+
async with self.reload_lock:
440+
await self.terminate_provider(provider_config["id"])
441+
if provider_config["enable"]:
442+
await self.load_provider(provider_config)
455443

456-
if len(self.stt_provider_insts) == 0:
457-
self.curr_stt_provider_inst = None
458-
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
459-
self.curr_stt_provider_inst = self.stt_provider_insts[0]
460-
logger.info(
461-
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
462-
)
444+
# 和配置文件保持同步
445+
self.providers_config = astrbot_config["provider"]
446+
config_ids = [provider["id"] for provider in self.providers_config]
447+
logger.info(f"providers in user's config: {config_ids}")
448+
for key in list(self.inst_map.keys()):
449+
if key not in config_ids:
450+
await self.terminate_provider(key)
463451

464-
if len(self.tts_provider_insts) == 0:
465-
self.curr_tts_provider_inst = None
466-
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
467-
self.curr_tts_provider_inst = self.tts_provider_insts[0]
468-
logger.info(
469-
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
470-
)
452+
if len(self.provider_insts) == 0:
453+
self.curr_provider_inst = None
454+
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
455+
self.curr_provider_inst = self.provider_insts[0]
456+
logger.info(
457+
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
458+
)
459+
460+
if len(self.stt_provider_insts) == 0:
461+
self.curr_stt_provider_inst = None
462+
elif (
463+
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
464+
):
465+
self.curr_stt_provider_inst = self.stt_provider_insts[0]
466+
logger.info(
467+
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
468+
)
469+
470+
if len(self.tts_provider_insts) == 0:
471+
self.curr_tts_provider_inst = None
472+
elif (
473+
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
474+
):
475+
self.curr_tts_provider_inst = self.tts_provider_insts[0]
476+
logger.info(
477+
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
478+
)
471479

472480
def get_insts(self):
473481
return self.provider_insts

0 commit comments

Comments
 (0)