Commit 3138c7c

addressing comments and fixing impl
1 parent 7485c14 commit 3138c7c

File tree: 4 files changed (+20, -50 lines)


src/lighteval/tasks/tasks/long_horizon_execution/constants.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@

 PROMPT_TEMPLATE_MULTI_START = """You are an AI assistant. I will provide you with a dictionary and then give you keys in groups of {k}.
 Your task is to keep a running total (starting from 0) by adding the values associated with the keys I provide.
-In each turn, I'll provide {k} keys (comma-separated).
+In each turn, I'll provide {k} key(s) (comma-separated).
 Respond with the current running sum, enclosed in <answer> tags.

 Dictionary to maintain:
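
The reworded sentence stays grammatical for every group size, including k=1. A minimal sketch of the effect, formatting only the line this hunk touches (the full template has more placeholders, e.g. {dict_str}, not exercised here):

line = "In each turn, I'll provide {k} key(s) (comma-separated)."
print(line.format(k=1))  # In each turn, I'll provide 1 key(s) (comma-separated).
print(line.format(k=5))  # In each turn, I'll provide 5 key(s) (comma-separated).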

src/lighteval/tasks/tasks/long_horizon_execution/main.py

Lines changed: 1 addition & 27 deletions

@@ -44,37 +44,11 @@
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
-from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES
+from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES, PROMPT_TEMPLATE_SINGLE
 from lighteval.tasks.tasks.long_horizon_execution.multi_turn import create_multi_turn_tasks
 from lighteval.tasks.tasks.long_horizon_execution.utils import _build_prompt_and_target


-# Single-turn prompt template
-PROMPT_TEMPLATE_SINGLE = """You are an AI assistant. I will provide you with a dictionary and then give you a list of keys.
-Your task is to calculate the final cumulative sum after processing all keys in order.
-
-For each key in the list, you need to:
-1. Look up the value in the dictionary
-2. Add it to the running sum
-3. After processing all keys, output the final cumulative sum
-
-Dictionary to use:
-{dict_str}
-
-Keys to process in order:
-{keys_str}
-
-Your task: Process all keys in order and calculate the final cumulative sum after processing all {num_keys} keys.
-
-IMPORTANT:
-- Output your answer as a single integer value inside <answer></answer> tags
-- Do not include any other text outside the answer tags
-- Format: <answer>final_sum</answer>
-- Example: If the final cumulative sum is 42, output: <answer>42</answer>
-
-Your answer:"""
-
-
 def single_turn_prompt_function(line, prompt_length=32768, task_name: str = None):
     """
     Prompt function for single-turn evaluation (non-inspect-ai backend).
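
The single-turn template now has one home in constants.py instead of a second copy here. A minimal usage sketch, assuming the moved definition kept the {dict_str}, {keys_str}, and {num_keys} placeholders visible in the removed block (the dictionary and keys below are made up):

from lighteval.tasks.tasks.long_horizon_execution.constants import PROMPT_TEMPLATE_SINGLE

# Toy inputs standing in for a dataset record.
dictionary = {"alpha": 3, "beta": -1, "gamma": 4}
keys = ["alpha", "gamma", "alpha"]

prompt = PROMPT_TEMPLATE_SINGLE.format(
    dict_str=str(dictionary),
    keys_str=", ".join(keys),
    num_keys=len(keys),
)
# Expected final sum for this toy example: 3 + 4 + 3 = 10.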

src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py

Lines changed: 13 additions & 14 deletions

@@ -67,14 +67,14 @@ def _extract_response_content(response):
     return str(response)


-async def _process_single_turn(state, turn_chunk, generate):
+async def _process_single_turn(state, turn_chunk, generate_fn):
     """Process a single turn: add user message, get model response, add assistant message."""
     keys_str = ", ".join(turn_chunk)
     followup_prompt = PROMPT_TEMPLATE_MULTI_FOLLOWUP.format(keys_str=keys_str)
     state.messages.append(ChatMessageUser(content=followup_prompt))

-    # generate() takes the state and returns updated state with assistant message added
-    updated_state = await generate(state)
+    # generate_fn() takes the state and returns updated state with assistant message added
+    updated_state = await generate_fn(state)
     turn_response = _extract_response_content(updated_state.output.completion if updated_state.output else "")

     return updated_state, turn_response
@@ -91,7 +91,7 @@ def multi_turn_solver():
     async def solve(state: TaskState, generate: Generate):
         turn_chunks = state.metadata.get("turn_chunks", [])

-        if not turn_chunks or len(turn_chunks) == 0:
+        if not turn_chunks:
             return state

         # Initialize messages
@@ -129,7 +129,7 @@ async def solve(state: TaskState, generate: Generate):
     return solve


-@scorer(metrics={"turn_accuracy": [accuracy(), stderr()], "fractional_accuracy": [accuracy(), stderr()]})
+@scorer(metrics={"fractional_accuracy": [accuracy(), stderr()]})
 def multi_turn_scorer():
     """
     Scorer for multi-turn Long Horizon Execution task.
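
Dropping turn_accuracy removes a duplicate: the old scorer reported the same fractional value under both keys, and the correct_turns/total_turns entries had no metrics registered for them. With a dict-valued Score, each key in the decorator's metrics mapping is aggregated independently; here is that aggregation in plain Python (illustrative per-sample numbers, not the inspect_ai implementation):

from statistics import mean, stdev

# fractional_accuracy from three hypothetical samples
per_sample = [1.0, 0.8, 0.6]
print(mean(per_sample))                            # accuracy() -> 0.8
print(stdev(per_sample) / len(per_sample) ** 0.5)  # stderr()   -> ~0.1155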
@@ -143,19 +143,23 @@ async def score(state: TaskState, target: Target):
         expected_per_turn = state.metadata.get("expected_per_turn", [])

         if not all_turn_outputs:
-            return Score(value=0.0, answer="", explanation="No turn outputs found in state.metadata")
+            return Score(
+                value={"fractional_accuracy": 0.0},
+                answer="",
+                explanation="No turn outputs found in state.metadata",
+            )

         if len(all_turn_outputs) != len(expected_per_turn):
             return Score(
-                value=0.0,
+                value={"fractional_accuracy": 0.0},
                 answer="",
                 explanation=f"Mismatch: {len(all_turn_outputs)} outputs vs {len(expected_per_turn)} expected turns",
             )

         parsed_outputs = []
         answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

-        for turn_idx, turn_output in enumerate(all_turn_outputs):
+        for turn_output in all_turn_outputs:
             match = answer_pattern.search(turn_output)
             if match:
                 try:
@@ -177,12 +181,7 @@ async def score(state: TaskState, target: Target):
         fractional_accuracy = correct_turns / len(expected_per_turn) if expected_per_turn else 0.0

         return Score(
-            value={
-                "turn_accuracy": fractional_accuracy,
-                "fractional_accuracy": fractional_accuracy,
-                "correct_turns": correct_turns,
-                "total_turns": len(expected_per_turn),
-            },
+            value={"fractional_accuracy": fractional_accuracy},
             answer=str(parsed_outputs),
             explanation=f"Correct {correct_turns}/{len(expected_per_turn)} turns. Details: {turn_results}",
         )
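
For reference, the scorer's per-turn logic distilled into a standalone function. A sketch: the parse-and-compare code between the two hunks above is elided from this diff, so the int() handling is an assumption; the regex and the correct_turns / len(expected_per_turn) ratio are taken verbatim.

import re

ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def fractional_accuracy(outputs: list[str], expected: list[int]) -> float:
    """Share of turns whose <answer> tag holds the expected running sum."""
    if not outputs or len(outputs) != len(expected):
        return 0.0
    correct = 0
    for out, exp in zip(outputs, expected):
        match = ANSWER_PATTERN.search(out)
        if match:
            try:
                if int(match.group(1).strip()) == exp:
                    correct += 1
            except ValueError:
                pass  # malformed answer counts as an incorrect turn
    return correct / len(expected)

# Two of three turns correct -> 2/3.
print(fractional_accuracy(["<answer>3</answer>", "<answer>10</answer>", "oops"], [3, 10, 12]))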

src/lighteval/tasks/tasks/long_horizon_execution/utils.py

Lines changed: 5 additions & 8 deletions

@@ -126,7 +126,9 @@ def build_initial_prompt_for_n(n):
         keys_str = ", ".join(first_turn_keys)

         return PROMPT_TEMPLATE_MULTI_START.format(
-            dict_str=dict_str, keys_str=keys_str, k=k, num_keys=len(first_turn_keys)
+            dict_str=dict_str,
+            keys_str=keys_str,
+            k=k,
         )

     return _binary_search_max_items(input_keys, build_initial_prompt_for_n, prompt_length, min_items=k)
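
Only the call site of _binary_search_max_items appears in this diff, but the idea is recoverable: grow n until build_initial_prompt_for_n(n) no longer fits in prompt_length. A hypothetical sketch of that search (the body, and the character-based length check, are assumptions):

def _binary_search_max_items(items, build_prompt, prompt_length, min_items=1):
    """Largest n in [min_items, len(items)] whose rendered prompt still fits."""
    lo, hi, best = min_items, len(items), min_items
    while lo <= hi:
        mid = (lo + hi) // 2
        if len(build_prompt(mid)) <= prompt_length:  # fits: try more items
            best, lo = mid, mid + 1
        else:                                        # too long: back off
            hi = mid - 1
    return best

Binary search is valid here because prompt length is monotonic in n: adding keys can only make the rendered prompt longer.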
@@ -169,7 +171,6 @@ def _build_multi_turn_prompts(record, prompt_length=32768, k=1):
     """
     input_keys = record["input"]
     input_values = record["values"]
-    expected_output = record["output"]

     # Handle empty input case
     if len(input_keys) == 0:
@@ -181,21 +182,17 @@ def _build_multi_turn_prompts(record, prompt_length=32768, k=1):
     # Use the maximum n that fits
     input_keys = input_keys[:max_n]
     input_values = input_values[:max_n]
-    expected_output = expected_output[:max_n]

-    turn_chunks, value_chunks, expected_per_turn = _chunk_and_calculate_expected(input_keys, input_values, k)
+    turn_chunks, _, expected_per_turn = _chunk_and_calculate_expected(input_keys, input_values, k)

     dictionary = dict(zip(input_keys, input_values))
     dict_str = str(dictionary)

     first_turn_keys_str = ", ".join(turn_chunks[0])
-    initial_prompt = PROMPT_TEMPLATE_MULTI_START.format(
-        dict_str=dict_str, keys_str=first_turn_keys_str, k=k, num_keys=len(turn_chunks[0])
-    )
+    initial_prompt = PROMPT_TEMPLATE_MULTI_START.format(dict_str=dict_str, keys_str=first_turn_keys_str, k=k)

     metadata = {
         "turn_chunks": turn_chunks,
-        "value_chunks": value_chunks,
         "expected_per_turn": expected_per_turn,
         "dictionary": dictionary,
         "k": k,
