Skip to content

Commit 7c01dfd

Browse files
committed
tweak
1 parent 9e1ff77 commit 7c01dfd

File tree

1 file changed

+1
-318
lines changed

1 file changed

+1
-318
lines changed

examples/agents_sdk/session_memory.ipynb

Lines changed: 1 addition & 318 deletions
Original file line number | Diff line number | Diff line change
@@ -805,324 +805,7 @@
805805
},
806806
{
807807
"cell_type": "code",
808-
"execution_count": null,
809-
"id": "4bb5c4e9",
810-
"metadata": {},
811-
"outputs": [],
812-
"source": [
813-
"import asyncio\n",
814-
"import itertools\n",
815-
"from collections import deque\n",
816-
"from typing import Optional, List, Tuple, Dict, Any\n",
817-
"\n",
818-
"class SummarizingSession:\n",
819-
" \"\"\"\n",
820-
" Keeps the last N *user* turns verbatim.\n",
821-
" Summarizes everything before that into a synthetic user→assistant pair.\n",
822-
" Internally stores (message, metadata) records. Exposes:\n",
823-
" - get_items(): model-safe messages only (no metadata)\n",
824-
" - get_full_history(): [{ \"message\": msg, \"metadata\": meta }, ...]\n",
825-
" \"\"\"\n",
826-
"\n",
827-
" # Only these keys are sent to the model. Everything else goes to metadata.\n",
828-
" _ALLOWED_MSG_KEYS = {\"role\", \"content\", \"name\"}\n",
829-
"\n",
830-
" def __init__(\n",
831-
" self,\n",
832-
" max_turns: int = 3,\n",
833-
" summarizer: Optional[\"Summarizer\"] = None,\n",
834-
" session_id: Optional[str] = None,\n",
835-
" ):\n",
836-
" assert max_turns >= 1\n",
837-
" self.max_turns = max_turns\n",
838-
" # Each record: {\"msg\": {...}, \"meta\": {...}}\n",
839-
" self._records: deque[Dict[str, Dict[str, Any]]] = deque()\n",
840-
" self._lock = asyncio.Lock()\n",
841-
" self.session_id = session_id or \"default\"\n",
842-
" self.summarizer = summarizer\n",
843-
"\n",
844-
" # --------- public API used by your runner ---------\n",
845-
"\n",
846-
" async def get_items(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:\n",
847-
" \"\"\"\n",
848-
" Returns messages in a model-safe shape (no metadata).\n",
849-
" Runner.run(..., session=self) should call this.\n",
850-
" \"\"\"\n",
851-
" async with self._lock:\n",
852-
" data = list(self._records)\n",
853-
" msgs = [self._sanitize_for_model(rec[\"msg\"]) for rec in data]\n",
854-
" return msgs[-limit:] if limit else msgs\n",
855-
"\n",
856-
" async def add_items(self, items: List[Dict[str, Any]]) -> None:\n",
857-
" async with self._lock:\n",
858-
" for it in items:\n",
859-
" msg, meta = self._split_msg_and_meta(it)\n",
860-
" self._records.append({\"msg\": msg, \"meta\": meta})\n",
861-
" need_summary, boundary_idx = self._should_summarize_locked()\n",
862-
"\n",
863-
" if need_summary:\n",
864-
" async with self._lock:\n",
865-
" prefix_records = list(itertools.islice(self._records, 0, boundary_idx))\n",
866-
" prefix_msgs = [r[\"msg\"] for r in prefix_records]\n",
867-
"\n",
868-
" user_shadow, assistant_summary = await self._summarize(prefix_msgs)\n",
869-
"\n",
870-
" async with self._lock:\n",
871-
" need_summary_now, boundary_idx_now = self._should_summarize_locked()\n",
872-
" if not need_summary_now:\n",
873-
" # normalize anyway if summarization got skipped\n",
874-
" self._normalize_synthetic_flags_locked()\n",
875-
" return\n",
876-
"\n",
877-
" suffix_records = list(itertools.islice(self._records, boundary_idx_now, None))\n",
878-
" self._records.clear()\n",
879-
"\n",
880-
" # Synthetic summary pair keeps synthetic=True\n",
881-
" self._records.extend([\n",
882-
" {\n",
883-
" \"msg\": {\"role\": \"user\", \"content\": user_shadow},\n",
884-
" \"meta\": {\n",
885-
" \"synthetic\": True,\n",
886-
" \"kind\": \"history_summary_prompt\",\n",
887-
" \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n",
888-
" },\n",
889-
" },\n",
890-
" {\n",
891-
" \"msg\": {\"role\": \"assistant\", \"content\": assistant_summary},\n",
892-
" \"meta\": {\n",
893-
" \"synthetic\": True,\n",
894-
" \"kind\": \"history_summary\",\n",
895-
" \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n",
896-
" },\n",
897-
" },\n",
898-
" ])\n",
899-
" self._records.extend(suffix_records)\n",
900-
"\n",
901-
" # ✅ Ensure all real messages explicitly have synthetic=False\n",
902-
" self._normalize_synthetic_flags_locked()\n",
903-
" else:\n",
904-
" # ✅ Even when we don't summarize, enforce the invariant\n",
905-
" async with self._lock:\n",
906-
" self._normalize_synthetic_flags_locked()\n",
907-
"\n",
908-
" async def pop_item(self) -> Optional[Dict[str, Any]]:\n",
909-
" async with self._lock:\n",
910-
" if not self._records:\n",
911-
" return None\n",
912-
" rec = self._records.pop()\n",
913-
" return dict(rec[\"msg\"]) # model-safe\n",
914-
"\n",
915-
" async def clear_session(self) -> None:\n",
916-
" async with self._lock:\n",
917-
" self._records.clear()\n",
918-
"\n",
919-
" def set_max_turns(self, n: int) -> None:\n",
920-
" assert n >= 1\n",
921-
" self.max_turns = n\n",
922-
"\n",
923-
" # --------- full-history (for debugging/analytics/observability) ---------\n",
924-
"\n",
925-
" # ✅ Backfill safeguard for older records that might lack the flag\n",
926-
" def _normalize_synthetic_flags_locked(self) -> None:\n",
927-
" for rec in self._records:\n",
928-
" role = rec[\"msg\"].get(\"role\")\n",
929-
" if role in (\"user\", \"assistant\") and \"synthetic\" not in rec[\"meta\"]:\n",
930-
" rec[\"meta\"][\"synthetic\"] = False\n",
931-
"\n",
932-
" \n",
933-
" async def get_full_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:\n",
934-
" \"\"\"\n",
935-
" Returns combined history where each entry is:\n",
936-
" { \"message\": {role, content[, name]}, \"metadata\": {...} }\n",
937-
" This is NOT sent to the model; it's for your logs/UI/debugging.\n",
938-
" \"\"\"\n",
939-
" async with self._lock:\n",
940-
" data = list(self._records)\n",
941-
" out = [{\"message\": dict(rec[\"msg\"]), \"metadata\": dict(rec[\"meta\"])} for rec in data]\n",
942-
" return out[-limit:] if limit else out\n",
943-
"\n",
944-
" # Backwards-compatible alias if you were using this name before\n",
945-
" async def get_items_with_metadata(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:\n",
946-
" return await self.get_full_history(limit)\n",
947-
"\n",
948-
" # --------- helpers ---------\n",
949-
"\n",
950-
" def _split_msg_and_meta(self, it: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n",
951-
" msg = {k: v for k, v in it.items() if k in self._ALLOWED_MSG_KEYS}\n",
952-
" extra = {k: v for k, v in it.items() if k not in self._ALLOWED_MSG_KEYS}\n",
953-
" meta = dict(extra.pop(\"metadata\", {}))\n",
954-
" meta.update(extra)\n",
955-
"\n",
956-
" if \"role\" not in msg or \"content\" not in msg:\n",
957-
" msg.setdefault(\"role\", \"user\")\n",
958-
" msg.setdefault(\"content\", str(it))\n",
959-
"\n",
960-
" # ✅ Default synthetic flag for real (non-summarized) messages\n",
961-
" role = msg.get(\"role\")\n",
962-
" if role in (\"user\", \"assistant\") and \"synthetic\" not in meta:\n",
963-
" meta[\"synthetic\"] = False\n",
964-
" return msg, meta\n",
965-
"\n",
966-
" def _sanitize_for_model(self, msg: Dict[str, Any]) -> Dict[str, Any]:\n",
967-
" \"\"\"\n",
968-
" Strictly keep only allowed keys for model input.\n",
969-
" \"\"\"\n",
970-
" return {k: v for k, v in msg.items() if k in self._ALLOWED_MSG_KEYS}\n",
971-
"\n",
972-
" def _is_user(self, rec: Dict[str, Dict[str, Any]]) -> bool:\n",
973-
" return rec[\"msg\"].get(\"role\") == \"user\"\n",
974-
"\n",
975-
" def _should_summarize_locked(self) -> Tuple[bool, int]:\n",
976-
" \"\"\"\n",
977-
" Find the earliest index among the last `max_turns` user messages.\n",
978-
" Everything before that index becomes the summarization prefix.\n",
979-
" \"\"\"\n",
980-
" idxs = []\n",
981-
" for i in range(len(self._records) - 1, -1, -1):\n",
982-
" if self._is_user(self._records[i]):\n",
983-
" idxs.append(i)\n",
984-
" if len(idxs) == self.max_turns:\n",
985-
" break\n",
986-
" if len(idxs) < self.max_turns:\n",
987-
" return False, -1\n",
988-
" boundary = min(idxs)\n",
989-
" if boundary <= 0:\n",
990-
" return False, -1\n",
991-
" return True, boundary\n",
992-
"\n",
993-
" async def _summarize(self, prefix_msgs: List[Dict[str, Any]]) -> Tuple[str, str]:\n",
994-
" \"\"\"\n",
995-
" Adapter to your summarizer. Provide *model-safe* messages only.\n",
996-
" \"\"\"\n",
997-
" if not self.summarizer:\n",
998-
" # Fallback summary if no summarizer is configured\n",
999-
" return (\"Summarize the conversation we had so far.\", \"Summary unavailable.\")\n",
1000-
" # Only send role/content/name to the summarizer as well\n",
1001-
" clean_prefix = [self._sanitize_for_model(m) for m in prefix_msgs]\n",
1002-
" return await self.summarizer.summarize(clean_prefix)\n"
1003-
]
1004-
},
1005-
{
1006-
"cell_type": "code",
1007-
"execution_count": 177,
1008-
"id": "a3e7cff8",
1009-
"metadata": {},
1010-
"outputs": [],
1011-
"source": [
1012-
"import asyncio\n",
1013-
"from collections import deque\n",
1014-
"from typing import Optional\n",
1015-
"import itertools\n",
1016-
"\n",
1017-
"class SummarizingSession:\n",
1018-
" \"\"\"\n",
1019-
" Keeps the last N user turns verbatim.\n",
1020-
" Summarizes everything before that into a synthetic user→assistant pair.\n",
1021-
" \"\"\"\n",
1022-
" def __init__(\n",
1023-
" self,\n",
1024-
" max_turns: int = 3,\n",
1025-
" summarizer: Optional[Summarizer] = None,\n",
1026-
" session_id: Optional[str] = None,\n",
1027-
" ):\n",
1028-
" assert max_turns >= 1\n",
1029-
" self.max_turns = max_turns\n",
1030-
" self._items: deque[Item] = deque()\n",
1031-
" self._lock = asyncio.Lock()\n",
1032-
" self.session_id = session_id or \"default\"\n",
1033-
" self.summarizer = summarizer\n",
1034-
"\n",
1035-
" # ----- public API that mirrors common Session interfaces -----\n",
1036-
"\n",
1037-
" async def get_items(self, limit: Optional[int] = None) -> list[Item]:\n",
1038-
" async with self._lock:\n",
1039-
" data = list(self._items)\n",
1040-
" return data[-limit:] if limit else data\n",
1041-
"\n",
1042-
" async def add_items(self, items: list[Item]) -> None:\n",
1043-
" # Append first\n",
1044-
" async with self._lock:\n",
1045-
" self._items.extend(items)\n",
1046-
" need_summary, boundary_idx = self._should_summarize_locked()\n",
1047-
"\n",
1048-
" # If we need a summary, **do it without the lock** to avoid blocking others\n",
1049-
" if need_summary:\n",
1050-
" # Take a snapshot of the prefix to summarize\n",
1051-
" async with self._lock:\n",
1052-
" prefix = list(itertools.islice(self._items, 0, boundary_idx))\n",
1053-
" # Produce the summary outside the lock\n",
1054-
" user_shadow, assistant_summary = await self.summarizer.summarize(prefix)\n",
1055-
"\n",
1056-
" # Re-acquire and re-check (in case of concurrent updates)\n",
1057-
" async with self._lock:\n",
1058-
" need_summary_now, boundary_idx_now = self._should_summarize_locked()\n",
1059-
" if need_summary_now:\n",
1060-
" suffix = list(itertools.islice(self._items, boundary_idx_now, None)) \n",
1061-
" self._items.clear()\n",
1062-
" self._items.extend([\n",
1063-
" {\n",
1064-
" \"role\": \"user\",\n",
1065-
" \"content\": user_shadow,\n",
1066-
" \"metadata\": {\n",
1067-
" \"synthetic\": True,\n",
1068-
" \"kind\": \"history_summary_prompt\",\n",
1069-
" \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n",
1070-
" },\n",
1071-
" },\n",
1072-
" {\n",
1073-
" \"role\": \"assistant\",\n",
1074-
" \"content\": assistant_summary,\n",
1075-
" \"metadata\": {\n",
1076-
" \"synthetic\": True,\n",
1077-
" \"kind\": \"history_summary\",\n",
1078-
" \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n",
1079-
" },\n",
1080-
" },\n",
1081-
" ])\n",
1082-
" self._items.extend(suffix)\n",
1083-
" # else: another concurrent writer already summarized; do nothing.\n",
1084-
"\n",
1085-
" async def pop_item(self) -> Optional[Item]:\n",
1086-
" async with self._lock:\n",
1087-
" return self._items.pop() if self._items else None\n",
1088-
"\n",
1089-
" async def clear_session(self) -> None:\n",
1090-
" async with self._lock:\n",
1091-
" self._items.clear()\n",
1092-
"\n",
1093-
" def set_max_turns(self, n: int) -> None:\n",
1094-
" assert n >= 1\n",
1095-
" self.max_turns = n\n",
1096-
"\n",
1097-
" # ----- helpers -----\n",
1098-
"\n",
1099-
" def _is_user(self, it: Item) -> bool:\n",
1100-
" return it.get(\"role\") == \"user\"\n",
1101-
"\n",
1102-
" def _should_summarize_locked(self) -> tuple[bool, int]:\n",
1103-
" \"\"\"\n",
1104-
" Returns (need_summary, boundary_idx).\n",
1105-
" boundary_idx = earliest index to keep (start of last N user turns).\n",
1106-
" If False, boundary_idx is undefined.\n",
1107-
" \"\"\"\n",
1108-
" idxs = []\n",
1109-
" for i in range(len(self._items) - 1, -1, -1):\n",
1110-
" if self._is_user(self._items[i]):\n",
1111-
" idxs.append(i)\n",
1112-
" if len(idxs) == self.max_turns:\n",
1113-
" break\n",
1114-
" if len(idxs) < self.max_turns:\n",
1115-
" return False, -1 # not enough user turns yet\n",
1116-
"\n",
1117-
" boundary = min(idxs) # earliest of the last N user turns\n",
1118-
" if boundary <= 0:\n",
1119-
" return False, -1 # nothing to summarize before boundary\n",
1120-
" return True, boundary\n"
1121-
]
1122-
},
1123-
{
1124-
"cell_type": "code",
1125-
"execution_count": 237,
808+
"execution_count": 250,
1126809
"id": "0d8bd4c5",
1127810
"metadata": {},
1128811
"outputs": [],

0 commit comments

Comments
 (0)