@@ -77,6 +77,7 @@ def simple_evaluate(
77
77
torch_random_seed : int = 1234 ,
78
78
fewshot_random_seed : int = 1234 ,
79
79
datetime_str : str = get_datetime_str (),
80
+ distributed_executor_backend : str = "accelerate" ,
80
81
cli_args = None ,
81
82
):
82
83
"""Instantiate and evaluate a model on a list of tasks.
@@ -133,7 +134,8 @@ def simple_evaluate(
133
134
Random seed for torch. If set to None, the seed will not be set.
134
135
:param fewshot_random_seed: int
135
136
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
136
-
137
+ :param distributed_executor_backend: str
138
+ The backend to use for distributed execution, either `accelerate` (HuggingFace Accelerate) or `torchrun` (torch.distributed). Defaults to "accelerate".
137
139
:return
138
140
Dictionary of results
139
141
"""
@@ -156,6 +158,8 @@ def simple_evaluate(
156
158
157
159
assert tasks != [], "No tasks specified, or no tasks found. Please verify the task names."
158
160
161
+ assert distributed_executor_backend in {"accelerate" , "torchrun" }, f"Invalid distributed executor backend: { distributed_executor_backend } . Choose either 'accelerate' or 'torchrun'."
162
+
159
163
if gen_kwargs :
160
164
gen_kwargs = simple_parse_args_string (gen_kwargs )
161
165
eval_logger .warning (f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks." )
@@ -258,6 +262,7 @@ def _adjust_config(task_dict):
258
262
apply_chat_template = apply_chat_template ,
259
263
fewshot_as_multiturn = fewshot_as_multiturn ,
260
264
verbosity = verbosity ,
265
+ distributed_executor_backend = distributed_executor_backend ,
261
266
cli_args = cli_args ,
262
267
)
263
268
@@ -319,6 +324,7 @@ def evaluate(
319
324
apply_chat_template : bool = False ,
320
325
fewshot_as_multiturn : bool = False ,
321
326
verbosity : str = "INFO" ,
327
+ distributed_executor_backend : str = "accelerate" ,
322
328
cli_args = None ,
323
329
):
324
330
"""Instantiate and evaluate a model on a list of tasks.
@@ -341,6 +347,8 @@ def evaluate(
341
347
If True, apply chat template to the prompt
342
348
:param fewshot_as_multiturn: bool
343
349
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
350
+ :param distributed_executor_backend: str
351
+ The backend to use for distributed execution, either `accelerate` (HuggingFace Accelerate) or `torchrun` (torch.distributed). Defaults to "accelerate".
344
352
:return
345
353
Dictionary of results
346
354
"""
@@ -432,8 +440,17 @@ def evaluate(
432
440
requests [reqtype ].append (instance )
433
441
434
442
if lm .world_size > 1 :
435
- instances_rnk = torch .tensor (len (task ._instances ), device = lm .device )
436
- gathered_item = lm .accelerator .gather (instances_rnk ).cpu ().detach ().numpy ().tolist ()
443
+ if distributed_executor_backend == "accelerate" :
444
+ instances_rnk = torch .tensor (len (task ._instances ), device = lm .device )
445
+ gathered_item = lm .accelerator .gather (instances_rnk ).cpu ().detach ().numpy ().tolist ()
446
+ elif distributed_executor_backend == "torchrun" :
447
+ instances_rnk = torch .tensor (len (task ._instances ), device = lm .device )
448
+ gathered_item = torch .zeros (lm .world_size * 1 , dtype = instances_rnk .dtype , device = lm .device )
449
+ dist .all_gather_into_tensor (gathered_item , instances_rnk )
450
+ gathered_item = gathered_item .cpu ().detach ().numpy ().tolist ()
451
+ else :
452
+ raise ValueError (f"Invalid distributed_executor_backend: { distributed_executor_backend } . Choose either 'accelerate' or 'torchrun'." )
453
+
437
454
# "multiple_choice" task types dispatch (several) "loglikelihood" request types
438
455
reqtype = "loglikelihood" if task .OUTPUT_TYPE == "multiple_choice" else task .OUTPUT_TYPE
439
456
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
@@ -462,7 +479,12 @@ def evaluate(
462
479
req .resps .append (x )
463
480
464
481
if lm .world_size > 1 :
465
- lm .accelerator .wait_for_everyone ()
482
+ if distributed_executor_backend == "accelerate" :
483
+ lm .accelerator .wait_for_everyone ()
484
+ elif distributed_executor_backend == "torchrun" :
485
+ dist .barrier ()
486
+ else :
487
+ raise ValueError (f"Invalid distributed_executor_backend: { distributed_executor_backend } . Choose either 'accelerate' or 'torchrun'." )
466
488
467
489
RANK = lm .rank
468
490
WORLD_SIZE = lm .world_size
@@ -638,8 +660,12 @@ def evaluate(
638
660
else :
639
661
results_dict = None
640
662
641
- if hasattr (lm , "accelerator" ):
663
+ if hasattr (lm , "accelerator" ) and distributed_executor_backend == "accelerate" :
642
664
lm .accelerator .wait_for_everyone ()
665
+ elif distributed_executor_backend == "torchrun" :
666
+ dist .barrier ()
667
+ else :
668
+ raise ValueError (f"Invalid distributed_executor_backend: { distributed_executor_backend } . Choose either 'accelerate' or 'torchrun'." )
643
669
644
670
return results_dict
645
671
0 commit comments