Skip to content

Commit 048a622

Browse files
committed
Add initial implementation of ChangeSummarizer and EpisodeAnalysis classes
1 parent 57bf273 commit 048a622

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from dataclasses import dataclass
2+
from bgym import StepInfo
3+
4+
5+
def _diff(past_obs, current_obs):
6+
"""TODO: Implement the diff function.
7+
8+
Returns a diff version of current_obs compares to past_obs, unless there is too many changes.
9+
"""
10+
raise ValueError("Not implemented yet.")
11+
12+
13+
@dataclass
14+
class ChangeSummarizer:
15+
16+
llm: callable # language model
17+
obs_formatter: callable
18+
use_diff: bool = False
19+
20+
def summarize(
21+
self, past_obs: dict, action: str, current_obs: dict, past_summaries: list[str]
22+
) -> str:
23+
"""Produces, a summary of the effect of an action."""
24+
past_obs_message = self.obs_formatter(past_obs)
25+
current_obs_message = self.obs_formatter(current_obs)
26+
if self.use_diff:
27+
current_obs_message = _diff(past_obs_message, current_obs_message)
28+
29+
return self.llm(self.make_prompt(past_obs_message, current_obs_message, action))
30+
31+
def make_prompt(self, past_obs_message, action, current_obs_message, past_summaries):
32+
"""TODO: Implement the prompt."""
33+
return f"{past_obs_message} {action} {current_obs_message}"
34+
35+
36+
@dataclass
37+
class EpisodeAnalysis:
38+
analysis: str # complete analysis of the episode
39+
summary: str # short summary of the analysis
40+
categories: dict[str, float] # score for each category e.g. type of error or difficulty levels
41+
42+
43+
@dataclass
44+
class EpisodeSummarizer:
45+
46+
cange_summarizer: ChangeSummarizer = None
47+
48+
def summarize(episode: list[StepInfo]) -> EpisodeAnalysis:
49+
"""Run Change Summarizer for every step in the episode or extract a pre-computed one."""
50+
pass

0 commit comments

Comments
 (0)