@@ -133,6 +133,7 @@ def __init__(self,
                  num_samples: Optional[int] = None,
                  random_seed: int = 0,
                  apply_chat_template: bool = False,
+                 fewshot_as_multiturn: bool = False,
                  system_prompt: Optional[str] = None):
         try:
             import lm_eval
@@ -141,8 +142,10 @@ def __init__(self,
                 f"Evaluation task {self.__class__.__name__} requires `lm_eval`. "
                 "Please install the package first, e.g., `pip install lm_eval`."
             ) from e
+        import lm_eval.tasks
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
+                         fewshot_as_multiturn=fewshot_as_multiturn,
                          system_prompt=system_prompt)
         self.task_name = task_name
         self.dataset_path = dataset_path
@@ -190,14 +193,16 @@ def compute_score(self, outputs: List[RequestOutput], references: List[str],
     def evaluate(self,
                  llm: Union[LLM, PyTorchLLM],
                  sampling_params: Optional[SamplingParams] = None,
-                 streaming: bool = False) -> float:
+                 streaming: bool = False,
+                 scores_filter: str = None) -> float:
         import lm_eval
-        results = lm_eval.evaluate(lm=LmEvalWrapper(llm, sampling_params,
-                                                    streaming),
-                                   task_dict=self.task_dict,
-                                   limit=self.num_samples,
-                                   apply_chat_template=self.apply_chat_template,
-                                   system_instruction=self.system_prompt)
+        results = lm_eval.evaluate(
+            lm=LmEvalWrapper(llm, sampling_params, streaming),
+            task_dict=self.task_dict,
+            limit=self.num_samples,
+            apply_chat_template=self.apply_chat_template,
+            fewshot_as_multiturn=self.fewshot_as_multiturn,
+            system_instruction=self.system_prompt)
         # Normalize scores to range 0~100
         scores = results["results"][self.task_name]
         for metric in scores.keys():
@@ -206,12 +211,17 @@ def evaluate(self,
         logger.info(
             f"lm-eval {self.task_name} results (scores normalized to range 0~100):\n{lm_eval.utils.make_table(results)}"
         )
-
-        average_acc = np.mean(
-            [acc for m, acc in scores.items() if "_stderr" not in m])
-        logger.info(
-            f"lm-eval {self.task_name} average accuracy: {average_acc:.2f}")
-        return average_acc
+        if scores_filter is not None:
+            result_acc = results["results"][self.task_name][scores_filter]
+            logger.info(
+                f"lm-eval {self.task_name} {scores_filter} accuracy: {result_acc:.2f}"
+            )
+        else:
+            result_acc = np.mean(
+                [acc for m, acc in scores.items() if "_stderr" not in m])
+            logger.info(
+                f"lm-eval {self.task_name} average accuracy: {result_acc:.2f}")
+        return result_acc
 
     @classmethod
     def command_harness(cls, ctx, **kwargs):
@@ -221,6 +231,8 @@ def command_harness(cls, ctx, **kwargs):
                         random_seed=kwargs.pop("random_seed", 0),
                         apply_chat_template=kwargs.pop("apply_chat_template",
                                                        False),
+                        fewshot_as_multiturn=kwargs.pop("fewshot_as_multiturn",
+                                                        False),
                         system_prompt=kwargs.pop("system_prompt", None))
         sampling_params = SamplingParams(
             max_tokens=kwargs.pop("max_output_length"),
@@ -254,6 +266,10 @@ def __init__(self, **kwargs):
                   is_flag=True,
                   default=False,
                   help="Whether to apply chat template.")
+    @click.option("--fewshot_as_multiturn",
+                  is_flag=True,
+                  default=False,
+                  help="Apply fewshot as multiturn.")
     @click.option("--system_prompt",
                   type=str,
                   default=None,
@@ -269,6 +285,10 @@ def __init__(self, **kwargs):
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:
+        if kwargs.get("fewshot_as_multiturn", False):
+            assert kwargs.get(
+                "apply_chat_template", False
+            ), "apply_chat_template must be True when fewshot_as_multiturn is True"
         GSM8K.command_harness(ctx, **kwargs)
 
 
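A minimal usage sketch of the new options, not part of the diff itself: it assumes the `GSM8K` evaluator and an already-constructed `llm` (LLM or PyTorchLLM) instance, that `SamplingParams` is importable from the package top level, and that the gsm8k task reports an `"exact_match,strict-match"` metric (an illustrative lm-eval metric key, not one fixed by this change).

```python
# Sketch only: exercising the new fewshot_as_multiturn / scores_filter options.
# Assumes `llm` is an already-built LLM/PyTorchLLM instance.
from tensorrt_llm import SamplingParams  # assumed import path

evaluator = GSM8K(
    random_seed=0,
    apply_chat_template=True,   # required when fewshot_as_multiturn is enabled
    fewshot_as_multiturn=True)
sampling_params = SamplingParams(max_tokens=256)

# With scores_filter set, evaluate() returns that single metric instead of the
# average over all non-stderr metrics.
accuracy = evaluator.evaluate(llm,
                              sampling_params,
                              scores_filter="exact_match,strict-match")
```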