Skip to content

Commit 2178e44

Browse files
authored
Fixing shuffle's NameError: name 'np' is not defined and type errors (#286)
1 parent 6f45cae commit 2178e44

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

src/aviary/utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77
from ast import literal_eval
88
from collections.abc import Sequence
99
from enum import StrEnum
10-
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, Self, TypeVar, cast
10+
from typing import (
11+
TYPE_CHECKING,
12+
Annotated,
13+
Any,
14+
ClassVar,
15+
Literal,
16+
Self,
17+
TypeAlias,
18+
TypeVar,
19+
cast,
20+
overload,
21+
)
1122
from uuid import UUID
1223

1324
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler, model_validator
@@ -21,6 +32,13 @@
2132
if TYPE_CHECKING:
2233
import numpy as np
2334

35+
# Work around super weird bug where np.random.Generator in quotes
36+
# is not being respected as a forward reference
37+
try:
38+
SeedTypes: TypeAlias = "int | random.Random | np.random.Generator | None"
39+
except ImportError: # NumPy isn't installed
40+
SeedTypes = int | random.Random | None # type: ignore[misc]
41+
2442

2543
DEFAULT_EVAL_MODEL_NAME = "gpt-4o-mini"
2644
LLM_BOOL_EVAL_CONFIG: dict[str, Any] = {
@@ -75,7 +93,7 @@ def get_default_config(self) -> dict[str, Any]:
7593
return {}
7694

7795

78-
def partial_format(value: str, **formats: dict[str, Any]) -> str:
96+
def partial_format(value: str, **formats) -> str:
7997
"""Partially format a string given a variable amount of formats."""
8098
for template_key, template_value in formats.items():
8199
with contextlib.suppress(KeyError):
@@ -253,9 +271,15 @@ def val_func(
253271
T = TypeVar("T")
254272

255273

256-
def shuffle(
257-
value: Sequence[T], seed: "int | random.Random | np.random.Generator | None" = None
258-
) -> Sequence[T]:
274+
@overload
275+
def shuffle(value: "np.ndarray", seed: SeedTypes = None) -> "np.ndarray": ...
276+
277+
278+
@overload
279+
def shuffle(value: Sequence[T], seed: SeedTypes = None) -> Sequence[T]: ...
280+
281+
282+
def shuffle(value, seed: SeedTypes = None):
259283
"""Shuffle a non-mutable sequence."""
260284
# Since most shuffle fn's are in-place, we employ sampling without replacement
261285
if isinstance(seed, int):
@@ -265,7 +289,7 @@ def shuffle(
265289
if seed is None:
266290
return random.sample(value, k=len(value))
267291
# Numpy RNG. Note this will have a type error for sequences like str, but oh well
268-
return seed.choice(value, size=len(value), replace=False) # type: ignore[arg-type,return-value]
292+
return seed.choice(value, size=len(value), replace=False)
269293

270294

271295
_CAPITAL_A_INDEX = ord("A")

0 commit comments

Comments
 (0)