@@ -9,6 +9,7 @@
 from functools import partial

 import numpy as np
+import torch
 import yaml

 warnings.simplefilter("ignore", category=DeprecationWarning)
@@ -314,13 +315,17 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     else:
         args_list.append(args)

-    # initialize Accelerator
-    kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
-    accelerator = Accelerator(kwargs_handlers=[kwargs_handler])
-    if accelerator.is_main_process:
-        is_main_process = True
+    # initialize Accelerator only if not already in a distributed context
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        accelerator = None
+        is_main_process = torch.distributed.get_rank() == 0
     else:
-        is_main_process = False
+        kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
+        accelerator = Accelerator(kwargs_handlers=[kwargs_handler])
+        if accelerator.is_main_process:
+            is_main_process = True
+        else:
+            is_main_process = False

     for args in args_list:
         try:
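This hunk skips Accelerator construction when a torch.distributed process group has already been initialized (typically when the script was launched under torchrun rather than accelerate launch), deriving main-process status straight from the rank. A minimal standalone sketch of that detection logic, assuming nothing beyond PyTorch itself; the helper name detect_distributed_context is introduced here for illustration and is not part of the patch:

import torch

def detect_distributed_context() -> tuple[bool, int]:
    """Return (already_distributed, rank); rank is 0 when no group exists."""
    # is_initialized() should only be called when the distributed package is
    # available, hence the is_available() guard mirrored from the patch.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return True, torch.distributed.get_rank()
    return False, 0

already_distributed, rank = detect_distributed_context()
is_main_process = rank == 0  # plays the role of accelerator.is_main_process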
@@ -330,7 +335,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             results, samples = cli_evaluate_single(args)
             results_list.append(results)

-            accelerator.wait_for_everyone()
+            if accelerator:
+                accelerator.wait_for_everyone()
+            elif torch.distributed.is_available() and torch.distributed.is_initialized():
+                torch.distributed.barrier()
             if is_main_process and args.wandb_args:
                 try:
                     wandb_logger.post_init(results)
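The synchronization point has to agree with the initialization path above: ranks launched under torchrun never build an Accelerator, so they block on torch.distributed.barrier() instead of accelerator.wait_for_everyone(). The same pattern isolated as a helper; the function name is ours, for illustration only:

import torch

def wait_for_all_ranks(accelerator=None):
    # Block until every rank reaches this point, whichever backend is active.
    if accelerator:
        accelerator.wait_for_everyone()
    elif torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    # plain single-process runs fall through: nothing to synchronize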
@@ -495,6 +503,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
         fewshot_random_seed=args.seed[3],
         cli_args=args,
         datetime_str=datetime_str,
+        distributed_executor_backend="torchrun" if (torch.distributed.is_available() and torch.distributed.is_initialized()) else "accelerate",
         **request_caching_args,
     )
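The last hunk threads the launcher choice through to the evaluation call as distributed_executor_backend, so downstream code knows whether torchrun or accelerate owns the process group. Assuming a module-style entry point for this harness (the module name, model, and task flags below are illustrative, not taken from the patch):

# Accelerate path (previous behavior): Accelerator creates the process group
accelerate launch -m lmms_eval --model llava --tasks mme

# torchrun path (enabled by this patch): if a process group was initialized
# earlier in startup, cli_evaluate detects it and skips Accelerator entirely
torchrun --nproc_per_node=8 -m lmms_eval --model llava --tasks mme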