diff --git a/src/wa_action_traces.py b/src/wa_action_traces.py index 8ff5b84..69ade37 100644 --- a/src/wa_action_traces.py +++ b/src/wa_action_traces.py @@ -111,8 +111,15 @@ def extract_trace(task_cls, headless=True): monkey_patch_playwright(observation_callback=env._get_obs, trace_storage=trace) env.reset() - env.task.cheat(env.page, env.chat.messages) - env.close() + # For compositional tasks, we need to cheat on each subtask + if hasattr(env.task, 'subtasks'): + # This is a compositional task, solve each subtask + for subtask_idx in range(len(env.task.subtasks)): + env.task.cheat(env.page, env.chat.messages, subtask_idx) + else: + # This is a regular task + env.task.cheat(env.page, env.chat.messages) + env.close() return trace