Skip to content

Commit 89de68a

Browse files
authored
Merge pull request NVIDIA#281 from sbak5/sbak/inject_fault
Move inject_fault from inprocess.tools to shared_utils
2 parents: efd1c9c + fbf4df5; commit: 89de68a

File tree

4 files changed, with 6 additions and 23 deletions.

src/nvidia_resiliency_ext/inprocess/tools/__init__.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/nvidia_resiliency_ext/inprocess/tools/inject_fault.py renamed to src/nvidia_resiliency_ext/shared_utils/inject_fault.py

File renamed without changes.

tests/inprocess/app.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
38 | 38 |   from packaging import version
39 | 39 |
40 | 40 |   import nvidia_resiliency_ext.inprocess as inprocess
41 |    | - import nvidia_resiliency_ext.inprocess.tools as tools
   | 41 | + from nvidia_resiliency_ext.shared_utils import inject_fault
42 | 42 |
43 | 43 |   os.environ['MASTER_ADDR'] = 'localhost'
44 | 44 |   os.environ['MASTER_PORT'] = '29500'
@@ -111,7 +111,7 @@ def parse_args(namespace=None, allow_extras=True):
111 | 111 |   train.add_argument('--ext-max-delay', default=2, type=float)
112 | 112 |   train.add_argument(
113 | 113 |   '--ext-fault',
114 |     | - type=lambda s: tools.inject_fault.Fault[s.upper()],
    | 114 | + type=lambda s: inject_fault.Fault[s.upper()],
115 | 115 |   default=None,
116 | 116 |   nargs='+',
117 | 117 |   )
@@ -375,7 +375,7 @@ def train(
375 | 375 |   maybe_trigger_fault(rank, world_size, args)
376 | 376 |
377 | 377 |   if args.ext_fault is not None:
378 |     | - tools.inject_fault.inject_fault(
    | 378 | + inject_fault.inject_fault(
379 | 379 |   faults=args.ext_fault,
380 | 380 |   num_faults=(args.min_faults, args.max_faults),
381 | 381 |   keep_alive=args.keep_alive,

tests/inprocess/test_app.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
21 | 21 |   from datetime import timedelta
22 | 22 |
23 | 23 |   import nvidia_resiliency_ext.inprocess as inprocess
24 |    | - import nvidia_resiliency_ext.inprocess.tools as tools
   | 24 | + from nvidia_resiliency_ext.shared_utils import inject_fault
25 | 25 |
26 | 26 |   from . import app, common
27 | 27 |
@@ -156,8 +156,8 @@ def train_kwargs():
156 | 156 |   @common.parametrize(
157 | 157 |   'ext_fault',
158 | 158 |   [
159 |     | - (tools.inject_fault.Fault.SIGKILL,),
160 |     | - (tools.inject_fault.Fault.SIGTERM,),
    | 159 | + (inject_fault.Fault.SIGKILL,),
    | 160 | + (inject_fault.Fault.SIGTERM,),
161 | 161 |   ],
162 | 162 |   )
163 | 163 |   def test_wo_exitcode(self, ext_fault):

Comments (0): there are no comments on this commit.