@@ -452,28 +452,35 @@ def _warn_deprecated_aliases(cls, data: Any) -> Any:
452452 @model_validator (mode = "after" )
453453 def _check_companion_services (self ) -> "RunConfig" :
454454 """Ensure required companion services are set for each pipeline mode."""
455+ required_keys = ["api_key" , "model" ]
455456 if isinstance (self .model , PipelineConfig ):
456457 if not self .model .stt :
457458 raise ValueError ("EVA_MODEL__STT is required when using EVA_MODEL__LLM (ASR-LLM-TTS pipeline)." )
458459 if not self .model .tts :
459460 raise ValueError ("EVA_MODEL__TTS is required when using EVA_MODEL__LLM (ASR-LLM-TTS pipeline)." )
460- self ._validate_service_params ("STT" , self .model .stt , self .model .stt_params )
461- self ._validate_service_params ("TTS" , self .model .tts , self .model .tts_params )
461+ self ._validate_service_params ("STT" , self .model .stt , required_keys , self .model .stt_params )
462+ self ._validate_service_params ("TTS" , self .model .tts , required_keys , self .model .tts_params )
462463 elif isinstance (self .model , AudioLLMConfig ):
463464 if not self .model .tts :
464465 raise ValueError ("EVA_MODEL__TTS is required when using EVA_MODEL__AUDIO_LLM (SpeechLM-TTS pipeline)." )
465- self ._validate_service_params ("TTS" , self .model .tts , self .model .tts_params )
466+ self ._validate_service_params ("TTS" , self .model .tts , required_keys , self .model .tts_params )
467+ self ._validate_service_params ("audio_llm" , self .model .audio_llm , required_keys , self .model .audio_llm_params )
468+ elif isinstance (self .model , SpeechToSpeechConfig ):
469+ # api_key is required, some s2s services don't require model
470+ self ._validate_service_params ("S2S" , self .model .s2s , ["api_key" ], self .model .s2s_params )
466471 return self
467472
468473 # Providers that manage their own model/key resolution (e.g. WebSocket-based)
469474 _SKIP_PARAMS_VALIDATION : ClassVar [set [str ]] = {"nvidia" }
470475
471476 @classmethod
472- def _validate_service_params (cls , service : str , provider : str , params : dict [str , Any ]) -> None :
477+ def _validate_service_params (
478+ cls , service : str , provider : str , required_keys : list [str ], params : dict [str , Any ]
479+ ) -> None :
473480 """Validate that STT/TTS params contain required keys."""
474481 if provider .lower () in cls ._SKIP_PARAMS_VALIDATION :
475482 return
476- missing = [key for key in ( "api_key" , "model" ) if key not in params ]
483+ missing = [key for key in required_keys if key not in params ]
477484 if missing :
478485 missing_str = " and " .join (f'"{ k } "' for k in missing )
479486 env_var = f"EVA_MODEL__{ service } _PARAMS"
0 commit comments