We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e316f86 commit e1a8ad4Copy full SHA for e1a8ad4
lmms_eval/evaluator.py
@@ -13,6 +13,7 @@
13
import numpy as np
14
import torch
15
import torch.distributed as dist
16
+from accelerate import Accelerator
17
from datasets import Image, Sequence
18
from loguru import logger as eval_logger
19
from tqdm import tqdm
@@ -660,8 +661,9 @@ def evaluate(
660
661
else:
662
results_dict = None
663
- if hasattr(lm, "accelerator") and distributed_executor_backend == "accelerate":
664
- lm.accelerator.wait_for_everyone()
+ if distributed_executor_backend == "accelerate":
665
+ # this should work for torchrun as well since it internally calls torch.distributed.barrier()
666
+ Accelerator().wait_for_everyone()
667
elif distributed_executor_backend == "torchrun":
668
dist.barrier()
669
0 commit comments