77from ast import literal_eval
88from collections .abc import Sequence
99from enum import StrEnum
10- from typing import TYPE_CHECKING , Annotated , Any , ClassVar , Literal , Self , TypeVar , cast
10+ from typing import (
11+ TYPE_CHECKING ,
12+ Annotated ,
13+ Any ,
14+ ClassVar ,
15+ Literal ,
16+ Self ,
17+ TypeAlias ,
18+ TypeVar ,
19+ cast ,
20+ overload ,
21+ )
1122from uuid import UUID
1223
1324from pydantic import BaseModel , ConfigDict , Field , GetCoreSchemaHandler , model_validator
2132if TYPE_CHECKING :
2233 import numpy as np
2334
35+ # Work around super weird bug where np.random.Generator in quotes
36+ # is not being respected as a forward reference
37+ try :
38+ SeedTypes : TypeAlias = "int | random.Random | np.random.Generator | None"
39+ except ImportError : # NumPy isn't installed
40+ SeedTypes = int | random .Random | None # type: ignore[misc]
41+
2442
2543DEFAULT_EVAL_MODEL_NAME = "gpt-4o-mini"
2644LLM_BOOL_EVAL_CONFIG : dict [str , Any ] = {
@@ -75,7 +93,7 @@ def get_default_config(self) -> dict[str, Any]:
7593 return {}
7694
7795
78- def partial_format (value : str , ** formats : dict [ str , Any ] ) -> str :
96+ def partial_format (value : str , ** formats ) -> str :
7997 """Partially format a string given a variable amount of formats."""
8098 for template_key , template_value in formats .items ():
8199 with contextlib .suppress (KeyError ):
@@ -253,9 +271,15 @@ def val_func(
253271T = TypeVar ("T" )
254272
255273
256- def shuffle (
257- value : Sequence [T ], seed : "int | random.Random | np.random.Generator | None" = None
258- ) -> Sequence [T ]:
274+ @overload
275+ def shuffle (value : "np.ndarray" , seed : SeedTypes = None ) -> "np.ndarray" : ...
276+
277+
278+ @overload
279+ def shuffle (value : Sequence [T ], seed : SeedTypes = None ) -> Sequence [T ]: ...
280+
281+
282+ def shuffle (value , seed : SeedTypes = None ):
259283 """Shuffle a non-mutable sequence."""
260284 # Since most shuffle fn's are in-place, we employ sampling without replacement
261285 if isinstance (seed , int ):
@@ -265,7 +289,7 @@ def shuffle(
265289 if seed is None :
266290 return random .sample (value , k = len (value ))
267291 # Numpy RNG. Note this will have a type error for sequences like str, but oh well
268- return seed .choice (value , size = len (value ), replace = False ) # type: ignore[arg-type,return-value]
292+ return seed .choice (value , size = len (value ), replace = False )
269293
270294
271295_CAPITAL_A_INDEX = ord ("A" )
0 commit comments