1- import logging
2- import warnings
3- from functools import partial
4- from typing import TYPE_CHECKING , Iterable , Optional , Union
1+ from __future__ import annotations
52
6- import datasets
3+ import logging
4+ from random import Random
5+ from typing import TYPE_CHECKING
76
87
98if TYPE_CHECKING :
10- from random import Random
9+ from collections .abc import Iterable , Sequence
10+ from typing import Any , TypeVar
1111
12- from lm_eval . api . task import ConfigurableTask , Task
12+ _T = TypeVar ( "_T" )
1313
14- eval_logger = logging .getLogger ("lm-eval" )
14+ eval_logger = logging .getLogger (__name__ )
1515
1616
class ContextSampler:
    """Randomly draw few-shot example documents from a pool of candidates.

    The pool lives in ``self.df``. When ``fewshot_indices`` is set, the pool is
    narrowed to those positions lazily, the first time documents are needed
    (see :meth:`fewshot_docs`); ``self._loaded`` records that this has happened.
    """

    def __init__(
        self,
        df: Sequence[dict[str, Any]] | None = None,
        *,
        rnd: int | Random | None = None,
        fewshot_indices: list[int] | None = None,
        **kwargs,
    ) -> None:
        """
        Args:
            df: Pool of candidate documents; ``None`` means an empty pool.
            rnd: Seed for the internal ``random.Random`` generator, or an
                already constructed ``random.Random`` instance to use as-is.
            fewshot_indices: Optional positions into ``df`` restricting which
                documents may be sampled.
            **kwargs: Accepted and ignored.
        """
        # A ready-made generator is used directly; anything else is treated as a seed.
        self.rnd = rnd if isinstance(rnd, Random) else Random(rnd)
        self.df = df or []
        self.fewshot_indices = fewshot_indices
        self._loaded = False  # True once fewshot_indices has been applied to self.df

    def sample(
        self,
        n: int,
        eval_doc: dict[str, Any] | None = None,
        df: Sequence[dict[str, Any]] | None = None,
        **kwargs,
    ) -> Sequence[dict[str, Any]]:
        """
        Sample n documents from the pool.

        Args:
            n: Number of documents to sample
            eval_doc: Optional document to exclude from sampling
            df: Optional list of documents to sample from; when given (and
                non-empty) it replaces the stored pool
            **kwargs: Accepted and ignored

        Returns:
            List of sampled documents

        Raises:
            AssertionError: if ``n`` is negative, the pool is empty, or fewer
                than ``n`` documents remain after excluding ``eval_doc``.
            ValueError: from ``random.Random.sample`` when no ``eval_doc`` is
                given and ``n`` exceeds the pool size.
        """
        assert n >= 0, "Error: number of samples requested must be >=0"
        if n == 0:
            return []

        if df and df is not self.df:
            # A new pool invalidates any cached `fewshot_indices` selection;
            # without the reset the indices would be ignored from here on.
            self.df = df
            self._loaded = False

        assert self.df, "Error: no documents available for sampling."
        pool = self.fewshot_docs()
        if not eval_doc:
            res = self.rnd.sample(pool, n)
        else:
            # Draw one spare so that dropping `eval_doc` still leaves n docs.
            # Capped at the pool size so a pool of exactly n docs that does not
            # contain `eval_doc` does not make `Random.sample` raise.
            drawn = self.rnd.sample(pool, min(n + 1, len(pool)))
            res = self.rm_eval_doc(eval_doc, drawn, n)
        assert len(res) == n, (
            f"Error: number of fewshot samples returned ({len(res)}) not equal to number requested ({n})."
        )
        return res

    def set_rnd(self, rnd: int | Random | None):
        """Replace the generator (seed or ``random.Random`` instance); returns ``self``."""
        self.rnd = rnd if isinstance(rnd, Random) else Random(rnd)
        return self

    def replace_df(self, df: Sequence[dict[str, Any]]):
        """Replace the document pool and drop any cached selection; returns ``self``."""
        self.df = df
        self._loaded = False
        return self

    def fewshot_docs(self):
        """Return the documents available for sampling.

        When ``fewshot_indices`` is set, the pool is narrowed to those positions
        once; the narrowed list is stored back on ``self.df`` and reused until
        the pool is replaced. Without indices, a fresh list copy of the pool is
        returned on every call.
        """
        if self._loaded:
            return self.df
        if self.fewshot_indices and self.df:
            self.df = [self.df[i] for i in self.fewshot_indices]
            self._loaded = True
            return self.df
        return list(self.df)

    @staticmethod
    def rm_eval_doc(doc: _T, _iter: Iterable[_T], n=None) -> Sequence[_T]:
        """Return the items of ``_iter`` not equal to ``doc``, truncated to ``n`` items when ``n`` is given."""
        kept = [x for x in _iter if x != doc]
        return kept if n is None else kept[:n]
19094
19195
class FirstNSampler(ContextSampler):
    def sample(self, n: int, eval_doc=None, df=None, **kwargs):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Args:
            n: Number of documents to return.
            eval_doc: Accepted for signature compatibility; not used.
            df: Optional replacement pool, stored on the sampler when given
                (and non-empty), as the base sampler does.
            **kwargs: Accepted and ignored.

        Returns:
            The first `n` documents of the pool after `fewshot_indices`
            (if any) has been applied.

        Raises:
            AssertionError: if `n` exceeds the number of available documents.
        """
        if df and df is not self.df:
            # A new pool invalidates any cached `fewshot_indices` selection.
            self.df = df
            self._loaded = False

        # Go through fewshot_docs() so that `fewshot_indices` is honoured;
        # reading self.df directly would bypass the lazy index selection.
        docs = self.fewshot_docs()
        assert n <= len(docs), (
            f"Error: number of fewshot samples requested exceeds the {len(docs)} that are available."
        )
        return docs[:n]
202106
203107
class BalancedSampler(ContextSampler):
    def sample(self, n: int, eval_doc=None, df=None, **kwargs):
        """
        Not implemented yet: calling this always raises `NotImplementedError`.

        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """
        raise NotImplementedError
212116
213117
class ManualSampler(ContextSampler):
    def sample(self, n: int, eval_doc=None, df=None, **kwargs):
        """Not implemented yet: calling this always raises `NotImplementedError`."""
        raise NotImplementedError
218122
219123
# Maps the sampler name accepted by `get_sampler` to the class implementing it.
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}
@@ -226,7 +130,7 @@ def sample(self, n: int) -> None:
def get_sampler(name: str) -> type[ContextSampler]:
    """Look up a sampler class by its registry name.

    Args:
        name: Key into `SAMPLER_REGISTRY`.

    Returns:
        The sampler class registered under `name` (the class itself, not an instance).

    Raises:
        KeyError: if `name` is not registered; the message lists the
            supported sampler names.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as e:
        # The list shows registered sampler names (the previous wording
        # mislabelled them as model names).
        raise KeyError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from e
0 commit comments