|
5 | 5 | from pathlib import Path |
6 | 6 | import argparse |
7 | 7 | import datetime |
| 8 | +import functools |
8 | 9 | import logging |
9 | 10 | import math |
10 | 11 | import os |
@@ -544,6 +545,28 @@ def train( |
544 | 545 | ) |
545 | 546 |
|
546 | 547 |
|
| 548 | +# This function makes an effort to stick to a default value from torch library, |
| 549 | +# whatever it may be. That's why we don't just set to the current (as of the |
| 550 | +# time of writing) default: to cover the unlikely event torch decides to tweak |
| 551 | +# the default. |
| 552 | +def _get_collective_timeout() -> datetime.timedelta | None: |
| 553 | + timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS") |
| 554 | + if timeout_var is None: |
| 555 | + return None |
| 556 | + |
| 557 | + try: |
| 558 | + timeout = int(timeout_var) |
| 559 | + except ValueError: |
| 560 | + timeout = -1 |
| 561 | + |
| 562 | + if timeout <= 0: |
| 563 | + raise ValueError( |
| 564 | + f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer." |
| 565 | + ) |
| 566 | + |
| 567 | + return datetime.timedelta(milliseconds=timeout) |
| 568 | + |
| 569 | + |
547 | 570 | def main(args): |
548 | 571 | if args.distributed_training_framework == "deepspeed" and not FusedAdam: |
549 | 572 | raise ImportError( |
@@ -571,15 +594,17 @@ def main(args): |
571 | 594 | model_conf = AutoConfig.from_pretrained(args.model_name_or_path) |
572 | 595 | args.model_type = model_conf.model_type |
573 | 596 |
|
574 | | - # solution discovered from torchtune https://github.com/pytorch/torchtune/issues/2093 |
575 | | - # gets converted to a timedelta of 1:40:00 if the default is kept |
576 | | - nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000")) |
577 | 597 | #### distributed init ##### |
578 | 598 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) |
579 | 599 | args.local_rank = int(os.environ["LOCAL_RANK"]) |
580 | | - torch.distributed.init_process_group( |
581 | | - "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout) |
582 | | - ) |
| 600 | + |
| 601 | + timeout = _get_collective_timeout() |
| 602 | + init = functools.partial(torch.distributed.init_process_group, "nccl") |
| 603 | + if timeout is not None: |
| 604 | + init(timeout=timeout) |
| 605 | + else: |
| 606 | + init() |
| 607 | + |
583 | 608 | args.global_rank = torch.distributed.get_rank() |
584 | 609 | tensor = torch.ByteTensor([False]).cuda() |
585 | 610 | torch.distributed.all_reduce(tensor) |
|
0 commit comments