
Commit 4036555

support distributed executor backend - torchrun (#680)
* support dist executor - torchrun
1 parent cd1d194 commit 4036555
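
The commit threads a new `distributed_executor_backend` keyword through `simple_evaluate` and `evaluate`, so multi-process runs launched with `torchrun` synchronize through `torch.distributed` directly instead of the `accelerate` wrapper. A minimal sketch of the updated call, assuming placeholder model and task names and a hypothetical launch script (the keyword and its default are from this commit; nothing else here is):

    from lmms_eval.evaluator import simple_evaluate

    # Each rank runs this script, launched e.g. via:
    #   torchrun --nproc_per_node=8 run_eval.py   (hypothetical script name)
    results = simple_evaluate(
        model="llava",    # placeholder model name
        tasks=["mme"],    # placeholder task list
        distributed_executor_backend="torchrun",  # new; defaults to "accelerate"
    )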

File tree

1 file changed: +31 -5 lines changed

lmms_eval/evaluator.py

Lines changed: 31 additions & 5 deletions
@@ -77,6 +77,7 @@ def simple_evaluate(
     torch_random_seed: int = 1234,
     fewshot_random_seed: int = 1234,
     datetime_str: str = get_datetime_str(),
+    distributed_executor_backend: str = "accelerate",
     cli_args=None,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -133,7 +134,8 @@ def simple_evaluate(
         Random seed for torch. If set to None, the seed will not be set.
     :param fewshot_random_seed: int
         Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
-
+    :param distributed_executor_backend: str
+        The backend to use for distributed execution, `accelerate` or `torchrun`. Defaults to "accelerate" for the `accelerate` library.
     :return
         Dictionary of results
     """
@@ -156,6 +158,8 @@ def simple_evaluate(
 
     assert tasks != [], "No tasks specified, or no tasks found. Please verify the task names."
 
+    assert distributed_executor_backend in {"accelerate", "torchrun"}, f"Invalid distributed executor backend: {distributed_executor_backend}. Choose either 'accelerate' or 'torchrun'."
+
     if gen_kwargs:
         gen_kwargs = simple_parse_args_string(gen_kwargs)
         eval_logger.warning(f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.")
@@ -258,6 +262,7 @@ def _adjust_config(task_dict):
         apply_chat_template=apply_chat_template,
         fewshot_as_multiturn=fewshot_as_multiturn,
         verbosity=verbosity,
+        distributed_executor_backend=distributed_executor_backend,
         cli_args=cli_args,
     )
 
@@ -319,6 +324,7 @@ def evaluate(
     apply_chat_template: bool = False,
     fewshot_as_multiturn: bool = False,
     verbosity: str = "INFO",
+    distributed_executor_backend: str = "accelerate",
     cli_args=None,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -341,6 +347,8 @@ def evaluate(
         If True, apply chat template to the prompt
     :param fewshot_as_multiturn: bool
         Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
+    :param distributed_executor_backend: str
+        The backend to use for distributed execution, `accelerate` or `torchrun`. Defaults to "accelerate" for the `accelerate` library.
     :return
         Dictionary of results
     """
@@ -432,8 +440,17 @@ def evaluate(
             requests[reqtype].append(instance)
 
         if lm.world_size > 1:
-            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
-            gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
+            if distributed_executor_backend == "accelerate":
+                instances_rnk = torch.tensor(len(task._instances), device=lm.device)
+                gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
+            elif distributed_executor_backend == "torchrun":
+                instances_rnk = torch.tensor(len(task._instances), device=lm.device)
+                gathered_item = torch.zeros(lm.world_size * 1, dtype=instances_rnk.dtype, device=lm.device)
+                dist.all_gather_into_tensor(gathered_item, instances_rnk)
+                gathered_item = gathered_item.cpu().detach().numpy().tolist()
+            else:
+                raise ValueError(f"Invalid distributed_executor_backend: {distributed_executor_backend}. Choose either 'accelerate' or 'torchrun'.")
+
             # "multiple_choice" task types dispatch (several) "loglikelihood" request types
             reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE
             # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
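
In the torchrun branch above, each rank's instance count is gathered into a pre-allocated flat tensor with `dist.all_gather_into_tensor` (hence the `world_size * 1` output size), replacing `accelerator.gather`. A self-contained sketch of the same pattern, runnable on CPU with the gloo backend in recent PyTorch (script name and counts are illustrative, not from the commit):

    import torch
    import torch.distributed as dist

    # Run with: torchrun --nproc_per_node=2 gather_counts.py  (hypothetical file name)
    def main():
        dist.init_process_group(backend="gloo")  # the evaluator would typically use NCCL on GPUs
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Stand-in for len(task._instances); counts may differ across ranks.
        local_count = torch.tensor([10 + rank])
        # all_gather_into_tensor requires output numel == world_size * input numel,
        # which is why the commit allocates torch.zeros(lm.world_size * 1, ...).
        gathered = torch.zeros(world_size, dtype=local_count.dtype)
        dist.all_gather_into_tensor(gathered, local_count)

        counts = gathered.tolist()
        # Mirrors the padding step that follows in evaluate(): ranks with fewer
        # instances pad with pseudo-batches up to the max so batches stay even.
        numpad = max(counts) - counts[rank]
        print(f"rank {rank}: counts={counts}, pad={numpad}")

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()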
@@ -462,7 +479,12 @@ def evaluate(
             req.resps.append(x)
 
         if lm.world_size > 1:
-            lm.accelerator.wait_for_everyone()
+            if distributed_executor_backend == "accelerate":
+                lm.accelerator.wait_for_everyone()
+            elif distributed_executor_backend == "torchrun":
+                dist.barrier()
+            else:
+                raise ValueError(f"Invalid distributed_executor_backend: {distributed_executor_backend}. Choose either 'accelerate' or 'torchrun'.")
 
     RANK = lm.rank
     WORLD_SIZE = lm.world_size
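
Both branches provide the same rendezvous guarantee: `accelerate`'s `wait_for_everyone()` is itself a barrier over the underlying process group, and the torchrun branch calls `torch.distributed.barrier()` directly. A toy illustration of that guarantee (script name hypothetical, gloo backend assumed for a CPU demo):

    import torch.distributed as dist

    # Run with: torchrun --nproc_per_node=2 sync_demo.py  (hypothetical file name)
    dist.init_process_group(backend="gloo")

    if dist.get_rank() == 0:
        # Rank-local work, analogous to rank 0 holding results in the evaluator.
        print("rank 0: doing work the other ranks must wait for")
    # No rank proceeds past the barrier until every rank has reached it,
    # matching accelerator.wait_for_everyone() in the accelerate branch.
    dist.barrier()
    print(f"rank {dist.get_rank()}: past the barrier")

    dist.destroy_process_group()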
@@ -638,8 +660,12 @@ def evaluate(
     else:
         results_dict = None
 
-    if hasattr(lm, "accelerator"):
+    if hasattr(lm, "accelerator") and distributed_executor_backend == "accelerate":
         lm.accelerator.wait_for_everyone()
+    elif distributed_executor_backend == "torchrun":
+        dist.barrier()
+    else:
+        raise ValueError(f"Invalid distributed_executor_backend: {distributed_executor_backend}. Choose either 'accelerate' or 'torchrun'.")
 
     return results_dict
 