Skip to content

Commit 1a57a86

Browse files
RobotSailMaxusmusti
authored andcommitted
chore: add exit code & tox fix
Currently, the training library does not exit when an error is encountered within the training loop (invoked through torchrun). This commit updates that functionality so we correctly return an exit code of 1 on child failure. Additionally, this commit also adds the `make fix` command which automatically fixes all trivial issues picked up on by ruff Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com> (cherry picked from commit 9c899dc)
1 parent aabd86b commit 1a57a86

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/instructlab/training/main_ds.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
745745
print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
746746
process = None
747747
interrupt: KeyboardInterrupt | Exception | None = None
748+
failure = False
748749
try:
749750
process = StreamablePopen(
750751
f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
@@ -755,19 +756,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
755756
print("Training subprocess interrupted by user.")
756757
interrupt = e
757758
except Exception as e:
758-
print(f"An error occurred: {str(e)}")
759+
print("Unexpected exception received during distributed training")
759760
interrupt = e
760761
finally:
761762
if "process" not in locals() or process is None:
762763
return
763-
if process.poll() == 0:
764-
print("\033[92mTraining subprocess exited successfully! 🎉\033[0m")
764+
765+
failure = process.poll() != 0
766+
if not failure:
767+
print("\033[92mOperation completed successfully! 🎉\033[0m")
765768
else:
766769
print(
767770
"\033[91mTraining subprocess has not exited yet. Sending SIGTERM.\033[0m"
768771
)
769772

770-
print("Sending interrupt signal to Training subprocess.")
771773
process.terminate()
772774
try:
773775
print("Waiting for process to exit, 60s...")
@@ -779,8 +781,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
779781
process.kill()
780782

781783
if interrupt:
782-
print(f"Error caught from training subprocess.: {interrupt}")
783784
raise interrupt
785+
if failure:
786+
raise RuntimeError(
787+
"Suffered a failure during distributed training. Please see the training logs for more context."
788+
)
784789

785790

786791
if __name__ == "__main__":

tox.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ commands =
6666
sh -c 'git diff --exit-code || (echo "pyproject.toml formatting is incorrect. Please run \"make toml-fmt\" and commit the changes." && exit 1)'
6767
allowlist_externals = make, sh
6868

69-
7069
[testenv:spellcheck]
7170
description = spell check (needs 'aspell' command)
7271
basepython = {[testenv:py3]basepython}

0 commit comments

Comments
 (0)