11import logging
22import os
3+ import re
34import shutil
5+ import string
46from pathlib import Path
57from typing import Any , Literal
68
79import datasets
810from pydantic import Field
9- from tapeagents .core import Observation , StopStep , Thought
11+ from tapeagents .core import Action , Observation , StopStep , Thought
1012from tapeagents .environment import ContainerExecutor , StatefulTool , Tool
1113from tapeagents .steps import ImageObservation
1214from tapeagents .tools .browser import Browser
1315from tapeagents .tools .code_executor import CodeExecutor
1416from tapeagents .tools .media_reader import VideoReader
1517from tapeagents .tools .web_search import WebSearch
1618
17- from agentlab .benchmarks .abstract_env import AbstractBenchmark , AbstractEnvArgs
19+ from agentlab .benchmarks .abstract_env import AbstractBenchmark , SerializableEnvArgs
1820from agentlab .benchmarks .multitool_gym import MultiToolGym
1921
2022logger = logging .getLogger (__name__ )
@@ -38,14 +40,41 @@ def reset(self, seed=None) -> tuple[list[Observation], dict]:
3840 steps .append (image_obs )
3941 return steps , {}
4042
def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
    """Run one step through the parent implementation, logging input and output.

    Args:
        action: The action to execute.

    Returns:
        The ``(observation, reward, terminated, truncated, env_info)`` tuple
        produced by the parent ``step``, passed through unchanged.
    """
    logger.info(f"Gym step called with action {type(action)}")
    # Unpack explicitly so a result of unexpected arity fails here, not in the caller.
    obs, step_reward, is_terminated, is_truncated, info = super().step(action)
    logger.info(f"Gym observation: {obs.short_view()}")
    return obs, step_reward, is_terminated, is_truncated, info
4448
def calculate_reward(self, action: Action) -> float:
    """Score ``action`` against the current task's reference answer.

    Only a ``GaiaAnswer`` can earn a reward: its ``answer`` is compared with
    ``self.task["Final answer"]`` via ``question_scorer``. Any other action
    scores zero. The outcome is logged together with the task id.

    Args:
        action: The action to evaluate.

    Returns:
        ``1.0`` if the action is a correct ``GaiaAnswer``, otherwise ``0.0``.
    """
    if isinstance(action, GaiaAnswer) and question_scorer(
        action.answer, self.task["Final answer"]
    ):
        logger.info(f"Task {self.task['task_id']} solved.")
        return 1.0

    # Either not a final answer, or a final answer that did not match.
    logger.info(f"Task {self.task['task_id']} failed.")
    return 0.0
63+
64+
65+ class GaiaGymArgs (SerializableEnvArgs ):
4766 task : dict [str , Any ]
48- viewport_chars : int = 64000
67+ viewport_chars : int
68+ task_seed : int
69+ task_name : str
70+
71+ def __init__ (
72+ self , task_name : str , task : dict [str , Any ], viewport_chars : int = 64000 , task_seed : int = 0
73+ ):
74+ self .task_name = task_name
75+ self .task = task
76+ self .viewport_chars = viewport_chars
77+ self .task_seed = task_seed
4978
5079 def make_env (self , exp_dir : str | Path , action_mapping = None ) -> GaiaGym :
5180 exp_dir = str (exp_dir )
@@ -80,7 +109,7 @@ def model_post_init(self, __context: Any) -> None:
80109 self .env_args_list = []
81110 dataset = datasets .load_dataset ("gaia-benchmark/GAIA" , "2023_all" )[self .split ]
82111 for task in dataset :
83- env_args = GaiaGymArgs (task_name = "gaia_ " + task ["task_id" ], task = task )
112+ env_args = GaiaGymArgs (task_name = "gaia. " + task ["task_id" ], task = task )
84113 self .env_args_list .append (env_args )
85114
86115
@@ -143,3 +172,96 @@ class GaiaAnswer(StopStep):
143172 )
144173 answer : Any = Field (description = "Short final answer" )
145174 long_answer : str = Field (description = "Detailed final answer not restricted by format rules" )
175+
176+
def normalize_number_str(number_str: str) -> float:
    """Convert a number-like string to ``float``, ignoring ``$``, ``%`` and ``,``.

    Args:
        number_str: Text such as ``"$1,234.50"`` or ``"50%"``.

    Returns:
        The parsed value, or ``float("inf")`` when the cleaned text is not a
        valid float literal (the failure is logged at INFO level).
    """
    # Strip the common unit symbols and thousands separators in one pass.
    cleaned = number_str.translate(str.maketrans("", "", "$%,"))
    try:
        value = float(cleaned)
    except ValueError:
        logger.info(f"String {cleaned} cannot be normalized to number str.")
        value = float("inf")
    return value
187+
188+
189+ def split_string (
190+ s : str ,
191+ char_list : list [str ] = ["," , ";" ],
192+ ) -> list [str ]:
193+ pattern = f"[{ '' .join (char_list )} ]"
194+ return re .split (pattern , s )
195+
196+
def question_scorer(
    model_answer: str,
    ground_truth: str,
) -> bool:
    """Decide whether ``model_answer`` matches ``ground_truth``.

    The comparison mode is chosen from the ground truth:

    - if it parses as a float, the answer is normalized to a number and
      compared numerically;
    - otherwise, if it contains ``,`` or ``;``, both strings are split into
      lists that must have the same length and match element by element
      (numerically where the ground-truth element is a number, otherwise as
      normalized strings with punctuation kept);
    - otherwise both are compared as normalized strings.

    Args:
        model_answer: Answer produced by the model. Non-string values
            (including ``None``) are converted with ``str()`` first so the
            string helpers below never receive a non-string.
        ground_truth: Reference answer.

    Returns:
        ``True`` when the answer is judged correct, ``False`` otherwise.
    """

    def is_float(element: object) -> bool:
        # float() raises TypeError (not ValueError) for None and other
        # non-numeric objects, so both must be treated as "not a number".
        try:
            float(element)
        except (TypeError, ValueError):
            return False
        return True

    def elements_match(ma_elem: str, gt_elem: str) -> bool:
        if is_float(gt_elem):
            return normalize_number_str(ma_elem) == float(gt_elem)
        # we do not remove punct since comparisons can include punct
        return normalize_str(ma_elem, remove_punct=False) == normalize_str(
            gt_elem, remove_punct=False
        )

    if not isinstance(model_answer, str):
        model_answer = str(model_answer)

    # if gt is a number
    if is_float(ground_truth):
        logger.info(f"Evaluating {model_answer} as a number.")
        return normalize_number_str(model_answer) == float(ground_truth)

    # if gt is a list
    if any(char in ground_truth for char in (",", ";")):
        logger.info(f"Evaluating {model_answer} as a comma separated list.")

        gt_elems = split_string(ground_truth)
        ma_elems = split_string(model_answer)

        # check length is the same
        if len(gt_elems) != len(ma_elems):
            # No extra positional args: logging would try to %-format them
            # into the message and report a formatting error instead.
            logger.warning("Answer lists have different lengths, returning False.")
            return False

        # compare each element as float or str; every pair is evaluated
        # before the verdict is taken.
        comparisons = [
            elements_match(ma_elem, gt_elem)
            for ma_elem, gt_elem in zip(ma_elems, gt_elems)
        ]
        return all(comparisons)

    # if gt is a str
    logger.info(f"Evaluating {model_answer} as a string.")
    return normalize_str(model_answer) == normalize_str(ground_truth)
245+
246+
def normalize_str(input_str, remove_punct=True) -> str:
    """
    Produce a canonical form of a string for loose comparison:
    - every whitespace character is dropped
    - the result is lowercased
    - punctuation is stripped unless remove_punct is False
    Parameters:
    - input_str: str, the string to normalize
    - remove_punct: bool, whether to remove punctuation (default: True)
    Returns:
    - str, the normalized string
    """
    # Dropping whitespace makes e.g. "sea gull" and "seagull" compare equal.
    collapsed = re.sub(r"\s", "", input_str).lower()
    if not remove_punct:
        return collapsed

    # Delete every ASCII punctuation character in a single pass.
    punctuation_table = str.maketrans("", "", string.punctuation)
    return collapsed.translate(punctuation_table)
0 commit comments