Skip to content

Commit 58c69c4

Browse files
committed
clean up loop.py
1 parent 017c203 commit 58c69c4

File tree

2 files changed

+258
-232
lines changed

2 files changed

+258
-232
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
from dataclasses import asdict, is_dataclass
3+
4+
import numpy as np
5+
from tapeagents.core import Step, StepMetadata
6+
from tapeagents.dialog_tape import AssistantStep, AssistantThought
7+
from tapeagents.io import save_json_tape, save_tape_images
8+
9+
from agentlab.agents.tapeagent.agent import DictObservation, Tape, TapeAgent
10+
11+
__all__ = ["as_tape", "save_tape", "TapeAgent", "Tape"]
12+
13+
14+
def as_tape(steps_info: list) -> Tape:
15+
"""
16+
Create a Tape object from the steps info.
17+
18+
Args:
19+
steps_info: list of StepInfo objects.
20+
21+
Returns:
22+
Tape: a Tape object containing the steps and metadata.
23+
"""
24+
25+
class JsonEncoder(json.JSONEncoder):
26+
def default(self, obj):
27+
if is_dataclass(obj):
28+
return asdict(obj) # type: ignore
29+
if isinstance(obj, np.integer):
30+
return int(obj)
31+
if isinstance(obj, np.floating):
32+
return float(obj)
33+
if isinstance(obj, np.ndarray):
34+
return obj.tolist()
35+
return super().default(obj)
36+
37+
steps: list[Step] = []
38+
for step_info in steps_info:
39+
if step_info.obs is not None:
40+
json_obs = json.dumps(step_info.obs, cls=JsonEncoder)
41+
steps.append(DictObservation(content=json_obs))
42+
if thought := step_info.agent_info.get("think"):
43+
steps.append(AssistantThought(content=thought))
44+
if step_info.action is not None:
45+
step_metadata = StepMetadata(
46+
other=dict(
47+
reward=step_info.reward,
48+
raw_reward=step_info.raw_reward,
49+
terminated=step_info.terminated,
50+
truncated=step_info.truncated,
51+
agent_info=step_info.agent_info,
52+
stats=step_info.stats,
53+
)
54+
)
55+
steps.append(AssistantStep(content=step_info.action, metadata=step_metadata))
56+
return Tape(steps=steps)
57+
58+
59+
def save_tape(exp_dir: str, episode_info: list, task: dict, tape: Tape):
60+
tape.metadata.reward = sum([step.reward for step in episode_info])
61+
tape.metadata.truncated = episode_info[-1].truncated
62+
tape.metadata.terminated = episode_info[-1].terminated
63+
tape.metadata.task = task
64+
save_json_tape(tape, exp_dir, "tape.json")
65+
save_tape_images(tape, f"{exp_dir}/tape_attachments")

0 commit comments

Comments
 (0)