Skip to content

Commit 8b656d2

Browse files
authored
Merge branch 'master' into optimize_logging
2 parents d8a94eb + 6ee822c commit 8b656d2

31 files changed

+1077
-49
lines changed

dlrover/python/common/comm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020 The DLRover Authors. All rights reserved.
1+
# Copyright 2026 The DLRover Authors. All rights reserved.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at

dlrover/python/common/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ class NodeEnv(object):
352352
# grpc env
353353
MASTER_CLIENT_TIMEOUT = "MASTER_CLIENT_TIMEOUT"
354354

355+
# extension env
356+
DLROVER_EXTENSION_DYNAMIC_FAILOVER = "DLROVER_EXTENSION_DYNAMIC_FAILOVER"
357+
355358

356359
class DatasetType(object):
357360
TEXT = "text"

dlrover/python/common/enums.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 The EasyDL Authors. All rights reserved.
1+
# Copyright 2026 The DLRover Authors. All rights reserved.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -17,3 +17,10 @@
1717
class ResourceType(Enum):
1818
CPU = "CPU"
1919
GPU = "GPU"
20+
21+
22+
class FailoverStrategy(Enum):
23+
NORMAL_FAILOVER = "NORMAL_FAILOVER"
24+
NODE_FAILOVER = "NODE_FAILOVER"
25+
GLOBAL_FAILOVER = "GLOBAL_FAILOVER"
26+
ABORTION_FAILOVER = "ABORTION_FAILOVER"

dlrover/python/common/failover.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2026 The DLRover Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
import time
14+
from abc import ABC, abstractmethod
15+
from dataclasses import dataclass, field
16+
from typing import Any
17+
18+
from dlrover.python.common.enums import FailoverStrategy
19+
20+
21+
USER_FAILOVER_TRIGGER_JOB_ABORTION = "USER_FAILOVER_TRIGGER_JOB_ABORTION"
22+
USER_FAILOVER_TRIGGER_JOB_RESTART = "USER_FAILOVER_TRIGGER_JOB_RESTART"
23+
24+
25+
@dataclass
26+
class FailureInfo(object):
27+
timestamp: int = int(time.time())
28+
log_content: str = ""
29+
extra_info: dict = field(default_factory=dict)
30+
31+
32+
class DynamicFailoverExtension(ABC):
33+
"""
34+
Dynamic extension for fault-tolerance execution.
35+
"""
36+
37+
@abstractmethod
38+
def get_user_failover_strategy(
39+
self, failure_info: Any
40+
) -> FailoverStrategy:
41+
"""
42+
The user-side implementation to specify a failover-strategy to DLRover
43+
according to the failure info of a process. Defaults to returning
44+
FailoverStrategy.NORMAL_FAILOVER, which employs DLRover's internal logic.
45+
46+
This implementation can be based on simple rule definitions using error
47+
codes or complex logic calls involving external services or model inference.
48+
49+
Args:
50+
failure_info (Any): The basic context when failure happens.
51+
52+
Returns:
53+
FailoverStrategy: The failover strategy.
54+
"""
55+
56+
return FailoverStrategy.NORMAL_FAILOVER

dlrover/python/common/global_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def __init__(self):
148148
self.max_relaunch_count = DefaultValues.MAX_RELAUNCH_COUNT
149149
self.max_group_relaunch_count = DefaultValues.MAX_GROUP_RELAUNCH_COUNT
150150
self.training_elastic_mode = DefaultValues.TRAINING_ELASTIC_MODE
151+
# extensions
152+
self.dynamic_failover_extension = None
151153

152154
def set_params_from_brain(self):
153155
self.train_speed_record_num = self.get_param_value_from_brain(

dlrover/python/diagnosis/common/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,18 @@ class DiagnosisActionType(object):
7575

7676
# master operation
7777
JOB_ABORT = "job_abortion"
78+
JOB_RESTART = "job_restart"
7879
MASTER_RELAUNCH_WORKER = "master_relaunch_worker"
7980
EVENT = "event"
8081

8182
# node operation
8283
RESTART_WORKER = "restart_worker"
8384
RELAUNCH_WORKER = "relaunch_worker"
8485

86+
# job operation
87+
RESTART_JOB = "restart_job"
88+
ABORT_JOB = "abort_job"
89+
8590

8691
class DiagnosisResult(object):
8792
# diag invalid param

dlrover/python/diagnosis/common/diagnosis_action.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,16 @@ def __repr__(self):
249249
)
250250

251251

252-
class JobAbortionAction(DiagnosisAction):
252+
class JobAction(DiagnosisAction):
253253
def __init__(
254254
self,
255+
action_type: str,
255256
reason: str = "",
256257
msg: str = "",
257258
**kwargs,
258259
):
259260
super().__init__(
260-
DiagnosisActionType.JOB_ABORT,
261+
action_type,
261262
DiagnosisConstant.MASTER_INSTANCE,
262263
0,
263264
0,
@@ -284,6 +285,34 @@ def __repr__(self):
284285
)
285286

286287

288+
class JobAbortionAction(JobAction):
289+
def __init__(
290+
self,
291+
reason: str = "",
292+
msg: str = "",
293+
**kwargs,
294+
):
295+
super().__init__(
296+
action_type=DiagnosisActionType.JOB_ABORT,
297+
reason=reason,
298+
msg=msg,
299+
)
300+
301+
302+
class JobRestartAction(JobAction):
303+
def __init__(
304+
self,
305+
reason: str = "",
306+
msg: str = "",
307+
**kwargs,
308+
):
309+
super().__init__(
310+
action_type=DiagnosisActionType.JOB_RESTART,
311+
reason=reason,
312+
msg=msg,
313+
)
314+
315+
287316
def is_same_action(action1: DiagnosisAction, action2: DiagnosisAction) -> bool:
288317
if isinstance(action1, EventAction) and isinstance(action2, EventAction):
289318
action1.__class__ = EventAction

dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 The DLRover Authors. All rights reserved.
1+
# Copyright 2026 The DLRover Authors. All rights reserved.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -18,7 +18,12 @@
1818

1919
from dlrover.python.common import env_utils
2020
from dlrover.python.common.constants import TrainingExceptionLevel
21+
from dlrover.python.common.enums import FailoverStrategy
2122
from dlrover.python.common.error import ProcessError
23+
from dlrover.python.common.failover import (
24+
USER_FAILOVER_TRIGGER_JOB_ABORTION,
25+
USER_FAILOVER_TRIGGER_JOB_RESTART,
26+
)
2227
from dlrover.python.common.log import default_logger as logger
2328
from dlrover.python.common.singleton import Singleton
2429
from dlrover.python.diagnosis.common.constants import (
@@ -30,6 +35,8 @@
3035
from dlrover.python.diagnosis.common.diagnosis_action import (
3136
DiagnosisAction,
3237
NodeAction,
38+
JobAbortionAction,
39+
JobRestartAction,
3340
)
3441
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
3542
from dlrover.python.diagnosis.common.diagnosis_manager import DiagnosisManager
@@ -50,6 +57,10 @@
5057
)
5158
from dlrover.python.elastic_agent.context import get_agent_context
5259
from dlrover.python.elastic_agent.master_client import MasterClient
60+
from dlrover.python.elastic_agent.torch.dynamic_failover import (
61+
DynamicAgentFailoverExtension,
62+
AgentFailureInfo,
63+
)
5364
from dlrover.python.training_event.config import is_dlrover_event_enabled
5465

5566

@@ -60,12 +71,16 @@ def __init__(
6071
errors="",
6172
node_rank=-1,
6273
local_world_size=0,
74+
dynamic_failover_extension=None,
6375
):
6476
self._client = MasterClient.singleton_instance()
6577
self._training_log_file = training_log_file
6678
self._errors = errors
6779
self._stopped = False
6880
self._agent_context = get_agent_context()
81+
self._extension: DynamicAgentFailoverExtension = (
82+
dynamic_failover_extension
83+
)
6984

7085
DiagnosisManager.__init__(self, self._agent_context)
7186

@@ -140,6 +155,71 @@ def diagnose_training_failure(self) -> DiagnosisAction:
140155
self._agent_context.run_result.failures,
141156
self._agent_context.restart_count,
142157
)
158+
159+
def serialize_failures(failures: dict):
160+
try:
161+
str_result = json.dumps(failures)
162+
except Exception:
163+
str_result = str(failures)
164+
return str_result
165+
166+
failure_info = AgentFailureInfo(
167+
node_rank=self._node_rank,
168+
log_content=serialize_failures(
169+
self._agent_context.run_result.failures
170+
),
171+
)
172+
173+
if self._extension is not None:
174+
extension_cls_info = self._extension.__class__
175+
176+
try:
177+
# user strategy
178+
user_strategy = self._extension.get_user_failover_strategy(
179+
failure_info
180+
)
181+
except Exception as e:
182+
logger.warning(
183+
f"Failed to get user_strategy from extension: {extension_cls_info} "
184+
f"by exception: {e}. Use default dlrover failover processing."
185+
)
186+
user_strategy = FailoverStrategy.NORMAL_FAILOVER
187+
188+
if user_strategy == FailoverStrategy.NODE_FAILOVER:
189+
logger.info(
190+
f"[{self._agent_context.worker_spec.role}] Worker group "
191+
f"{self._agent_context.run_result.state.name}, "
192+
f"will relaunch node by user strategy: {extension_cls_info}."
193+
)
194+
return NodeAction(
195+
node_id=env_utils.get_node_id(),
196+
node_type=env_utils.get_node_type(),
197+
instance=DiagnosisConstant.LOCAL_INSTANCE,
198+
action_type=DiagnosisActionType.RELAUNCH_WORKER,
199+
)
200+
elif user_strategy == FailoverStrategy.ABORTION_FAILOVER:
201+
logger.info(
202+
f"[{self._agent_context.worker_spec.role}] Worker group "
203+
f"{self._agent_context.run_result.state.name}, "
204+
f"will abort job by user strategy: {extension_cls_info}."
205+
)
206+
return JobAbortionAction(
207+
reason=USER_FAILOVER_TRIGGER_JOB_ABORTION
208+
)
209+
elif user_strategy == FailoverStrategy.GLOBAL_FAILOVER:
210+
logger.info(
211+
f"[{self._agent_context.worker_spec.role}] Worker group "
212+
f"{self._agent_context.run_result.state.name}, "
213+
f"will relaunch job by user strategy: {extension_cls_info}."
214+
)
215+
return JobRestartAction(
216+
reason=USER_FAILOVER_TRIGGER_JOB_RESTART
217+
)
218+
else:
219+
# FailoverStrategy.NORMAL_FAILOVER: continue with dlrover default logic
220+
pass
221+
222+
# dlrover default logic
143223
ob = self.observe(
144224
DiagnosticianType.NODE_FAILURE,
145225
log_file=self._training_log_file,

dlrover/python/elastic_agent/master_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,13 @@ def set_rdzv_blocked(self, blocked, reason=""):
519519
message = comm.RdzvBlocked(blocked=blocked, reason=reason)
520520
self._report(message)
521521

522+
def report_action(self, action: DiagnosisAction):
523+
message = comm.DiagnosisAction(
524+
action_cls=action.__class__.__name__,
525+
action_content=action.to_json(),
526+
)
527+
self._report(message)
528+
522529
@classmethod
523530
def singleton_instance(cls, *args, **kwargs):
524531
if not cls._instance:
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2026 The DLRover Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
from abc import abstractmethod
14+
from dataclasses import dataclass
15+
16+
from dlrover.python.common.enums import FailoverStrategy
17+
from dlrover.python.common.failover import (
18+
DynamicFailoverExtension,
19+
FailureInfo,
20+
)
21+
22+
23+
@dataclass
24+
class AgentFailureInfo(FailureInfo):
25+
node_rank: int = -1
26+
27+
28+
class DynamicAgentFailoverExtension(DynamicFailoverExtension):
29+
"""
30+
Dynamic extension for agent(elastic agent) fault-tolerance execution.
31+
"""
32+
33+
@abstractmethod
34+
def get_user_failover_strategy(
35+
self, failure_info: AgentFailureInfo
36+
) -> FailoverStrategy:
37+
"""
38+
The user-side implementation to specify a failover-strategy to DLRover
39+
according to the failure info of a process. Defaults to returning
40+
FailoverStrategy.NORMAL_FAILOVER, which employs DLRover's internal logic.
41+
42+
This implementation can be based on simple rule definitions using error
43+
codes or complex logic calls involving external services or model inference.
44+
45+
Args:
46+
failure_info (AgentFailureInfo): The basic failure context of agent
47+
when failure happens.
48+
49+
Returns:
50+
FailoverStrategy: The failover strategy.
51+
"""
52+
53+
return FailoverStrategy.NORMAL_FAILOVER

0 commit comments

Comments
 (0)