1- # Copyright 2024 The DLRover Authors. All rights reserved.
1+ # Copyright 2026 The DLRover Authors. All rights reserved.
22# Licensed under the Apache License, Version 2.0 (the "License");
33# you may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
1818
1919from dlrover .python .common import env_utils
2020from dlrover .python .common .constants import TrainingExceptionLevel
21+ from dlrover .python .common .enums import FailoverStrategy
2122from dlrover .python .common .error import ProcessError
23+ from dlrover .python .common .failover import (
24+ USER_FAILOVER_TRIGGER_JOB_ABORTION ,
25+ USER_FAILOVER_TRIGGER_JOB_RESTART ,
26+ )
2227from dlrover .python .common .log import default_logger as logger
2328from dlrover .python .common .singleton import Singleton
2429from dlrover .python .diagnosis .common .constants import (
3035from dlrover .python .diagnosis .common .diagnosis_action import (
3136 DiagnosisAction ,
3237 NodeAction ,
38+ JobAbortionAction ,
39+ JobRestartAction ,
3340)
3441from dlrover .python .diagnosis .common .diagnosis_data import WorkerTrainingMetric
3542from dlrover .python .diagnosis .common .diagnosis_manager import DiagnosisManager
5057)
5158from dlrover .python .elastic_agent .context import get_agent_context
5259from dlrover .python .elastic_agent .master_client import MasterClient
60+ from dlrover .python .elastic_agent .torch .dynamic_failover import (
61+ DynamicAgentFailoverExtension ,
62+ AgentFailureInfo ,
63+ )
5364from dlrover .python .training_event .config import is_dlrover_event_enabled
5465
5566
@@ -60,12 +71,16 @@ def __init__(
6071 errors = "" ,
6172 node_rank = - 1 ,
6273 local_world_size = 0 ,
74+ dynamic_failover_extension = None ,
6375 ):
6476 self ._client = MasterClient .singleton_instance ()
6577 self ._training_log_file = training_log_file
6678 self ._errors = errors
6779 self ._stopped = False
6880 self ._agent_context = get_agent_context ()
81+ self ._extension : DynamicAgentFailoverExtension = (
82+ dynamic_failover_extension
83+ )
6984
7085 DiagnosisManager .__init__ (self , self ._agent_context )
7186
@@ -140,6 +155,71 @@ def diagnose_training_failure(self) -> DiagnosisAction:
140155 self ._agent_context .run_result .failures ,
141156 self ._agent_context .restart_count ,
142157 )
158+
159+ def serialize_failures (failures : dict ):
160+ try :
161+ str_result = json .dumps (failures )
162+ except Exception :
163+ str_result = str (failures )
164+ return str_result
165+
166+ failure_info = AgentFailureInfo (
167+ node_rank = self ._node_rank ,
168+ log_content = serialize_failures (
169+ self ._agent_context .run_result .failures
170+ ),
171+ )
172+
173+ if self ._extension is not None :
174+ extension_cls_info = self ._extension .__class__
175+
176+ try :
177+ # user strategy
178+ user_strategy = self ._extension .get_user_failover_strategy (
179+ failure_info
180+ )
181+ except Exception as e :
182+ logger .warning (
183+ f"Failed to get user_strategy from extension: { extension_cls_info } "
184+ f"by exception: { e } . Use default dlrover failover processing."
185+ )
186+ user_strategy = FailoverStrategy .NORMAL_FAILOVER
187+
188+ if user_strategy == FailoverStrategy .NODE_FAILOVER :
189+ logger .info (
190+ f"[{ self ._agent_context .worker_spec .role } ] Worker group "
191+ f"{ self ._agent_context .run_result .state .name } , "
192+ f"will relaunch node by user strategy: { extension_cls_info } ."
193+ )
194+ return NodeAction (
195+ node_id = env_utils .get_node_id (),
196+ node_type = env_utils .get_node_type (),
197+ instance = DiagnosisConstant .LOCAL_INSTANCE ,
198+ action_type = DiagnosisActionType .RELAUNCH_WORKER ,
199+ )
200+ elif user_strategy == FailoverStrategy .ABORTION_FAILOVER :
201+ logger .info (
202+ f"[{ self ._agent_context .worker_spec .role } ] Worker group "
203+ f"{ self ._agent_context .run_result .state .name } , "
204+ f"will abort job by user strategy: { extension_cls_info } ."
205+ )
206+ return JobAbortionAction (
207+ reason = USER_FAILOVER_TRIGGER_JOB_ABORTION
208+ )
209+ elif user_strategy == FailoverStrategy .GLOBAL_FAILOVER :
210+ logger .info (
211+ f"[{ self ._agent_context .worker_spec .role } ] Worker group "
212+ f"{ self ._agent_context .run_result .state .name } , "
213+ f"will relaunch job by user strategy: { extension_cls_info } ."
214+ )
215+ return JobRestartAction (
216+ reason = USER_FAILOVER_TRIGGER_JOB_RESTART
217+ )
218+ else :
219+ # FailoverStrategy.NORMAL_FAILOVER: continue with dlrover default logic
220+ pass
221+
222+ # dlrover default logic
143223 ob = self .observe (
144224 DiagnosticianType .NODE_FAILURE ,
145225 log_file = self ._training_log_file ,
0 commit comments