Skip to content

Commit 31dce05

Browse files
committed
import in getters, so registry is populated
1 parent 99e5953 commit 31dce05

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

lm_eval/api/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def decorate(cls):
486486
return decorate
487487

488488

489-
def get_model(model_name):
489+
def get_model(model_name: str):
490490
"""Get a model class by name.
491491
492492
Args:

lm_eval/evaluator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import TYPE_CHECKING, List, Optional, Union
99

1010
import numpy as np
11-
import torch
1211

1312
import lm_eval.api.metrics
1413
import lm_eval.api.registry
@@ -33,6 +32,7 @@
3332
hash_dict_images,
3433
hash_string,
3534
positional_deprecated,
35+
set_torch_seed,
3636
setup_logging,
3737
simple_parse_args_string,
3838
wrap_text,
@@ -193,7 +193,7 @@ def simple_evaluate(
193193

194194
if torch_random_seed is not None:
195195
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
196-
torch.manual_seed(torch_random_seed)
196+
set_torch_seed(torch_random_seed)
197197

198198
if fewshot_random_seed is not None:
199199
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
@@ -553,6 +553,8 @@ def evaluate(
553553
requests[reqtype].append(instance)
554554

555555
if lm.world_size > 1:
556+
import torch
557+
556558
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
557559
gathered_item = (
558560
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
@@ -661,6 +663,8 @@ def evaluate(
661663
task_output.sample_metrics[(metric, filter_key)].append(value)
662664

663665
if WORLD_SIZE > 1:
666+
import torch
667+
664668
# if multigpu, then gather data across all ranks to rank 0
665669
# first gather logged samples across all ranks
666670
for task_output in eval_tasks:

lm_eval/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,3 +840,9 @@ def _request_with_retries(method, url, **kwargs):
840840
return False
841841

842842
return True
843+
844+
845+
def set_torch_seed(seed: int):
846+
import torch
847+
848+
torch.manual_seed(seed)

0 commit comments

Comments
 (0)