Skip to content

Commit e76db8c

Browse files
authored
Refactor Accelerator initialization in cli_evaluate (#717)
1 parent 3d9ccbf commit e76db8c

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

lmms_eval/__main__.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import partial
1010

1111
import numpy as np
12+
import torch
1213
import yaml
1314

1415
warnings.simplefilter("ignore", category=DeprecationWarning)
@@ -314,13 +315,17 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
314315
else:
315316
args_list.append(args)
316317

317-
# initialize Accelerator
318-
kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
319-
accelerator = Accelerator(kwargs_handlers=[kwargs_handler])
320-
if accelerator.is_main_process:
321-
is_main_process = True
318+
# initialize Accelerator only if not already in a distributed context
319+
if torch.distributed.is_available() and torch.distributed.is_initialized():
320+
accelerator = None
321+
is_main_process = torch.distributed.get_rank() == 0
322322
else:
323-
is_main_process = False
323+
kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
324+
accelerator = Accelerator(kwargs_handlers=[kwargs_handler])
325+
if accelerator.is_main_process:
326+
is_main_process = True
327+
else:
328+
is_main_process = False
324329

325330
for args in args_list:
326331
try:
@@ -330,7 +335,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
330335
results, samples = cli_evaluate_single(args)
331336
results_list.append(results)
332337

333-
accelerator.wait_for_everyone()
338+
if accelerator:
339+
accelerator.wait_for_everyone()
340+
elif torch.distributed.is_available() and torch.distributed.is_initialized():
341+
torch.distributed.barrier()
334342
if is_main_process and args.wandb_args:
335343
try:
336344
wandb_logger.post_init(results)
@@ -495,6 +503,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
495503
fewshot_random_seed=args.seed[3],
496504
cli_args=args,
497505
datetime_str=datetime_str,
506+
distributed_executor_backend="torchrun" if (torch.distributed.is_available() and torch.distributed.is_initialized()) else "accelerate",
498507
**request_caching_args,
499508
)
500509

0 commit comments

Comments
 (0)