|
805 | 805 | },
|
806 | 806 | {
|
807 | 807 | "cell_type": "code",
|
808 |
| - "execution_count": null, |
809 |
| - "id": "4bb5c4e9", |
810 |
| - "metadata": {}, |
811 |
| - "outputs": [], |
812 |
| - "source": [ |
813 |
| - "import asyncio\n", |
814 |
| - "import itertools\n", |
815 |
| - "from collections import deque\n", |
816 |
| - "from typing import Optional, List, Tuple, Dict, Any\n", |
817 |
| - "\n", |
818 |
| - "class SummarizingSession:\n", |
819 |
| - " \"\"\"\n", |
820 |
| - " Keeps the last N *user* turns verbatim.\n", |
821 |
| - " Summarizes everything before that into a synthetic user→assistant pair.\n", |
822 |
| - " Internally stores (message, metadata) records. Exposes:\n", |
823 |
| - " - get_items(): model-safe messages only (no metadata)\n", |
824 |
| - " - get_full_history(): [{ \"message\": msg, \"metadata\": meta }, ...]\n", |
825 |
| - " \"\"\"\n", |
826 |
| - "\n", |
827 |
| - " # Only these keys are sent to the model. Everything else goes to metadata.\n", |
828 |
| - " _ALLOWED_MSG_KEYS = {\"role\", \"content\", \"name\"}\n", |
829 |
| - "\n", |
830 |
| - " def __init__(\n", |
831 |
| - " self,\n", |
832 |
| - " max_turns: int = 3,\n", |
833 |
| - " summarizer: Optional[\"Summarizer\"] = None,\n", |
834 |
| - " session_id: Optional[str] = None,\n", |
835 |
| - " ):\n", |
836 |
| - " assert max_turns >= 1\n", |
837 |
| - " self.max_turns = max_turns\n", |
838 |
| - " # Each record: {\"msg\": {...}, \"meta\": {...}}\n", |
839 |
| - " self._records: deque[Dict[str, Dict[str, Any]]] = deque()\n", |
840 |
| - " self._lock = asyncio.Lock()\n", |
841 |
| - " self.session_id = session_id or \"default\"\n", |
842 |
| - " self.summarizer = summarizer\n", |
843 |
| - "\n", |
844 |
| - " # --------- public API used by your runner ---------\n", |
845 |
| - "\n", |
846 |
| - " async def get_items(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:\n", |
847 |
| - " \"\"\"\n", |
848 |
| - " Returns messages in a model-safe shape (no metadata).\n", |
849 |
| - " Runner.run(..., session=self) should call this.\n", |
850 |
| - " \"\"\"\n", |
851 |
| - " async with self._lock:\n", |
852 |
| - " data = list(self._records)\n", |
853 |
| - " msgs = [self._sanitize_for_model(rec[\"msg\"]) for rec in data]\n", |
854 |
| - " return msgs[-limit:] if limit else msgs\n", |
855 |
| - "\n", |
856 |
| - " async def add_items(self, items: List[Dict[str, Any]]) -> None:\n", |
857 |
| - " async with self._lock:\n", |
858 |
| - " for it in items:\n", |
859 |
| - " msg, meta = self._split_msg_and_meta(it)\n", |
860 |
| - " self._records.append({\"msg\": msg, \"meta\": meta})\n", |
861 |
| - " need_summary, boundary_idx = self._should_summarize_locked()\n", |
862 |
| - "\n", |
863 |
| - " if need_summary:\n", |
864 |
| - " async with self._lock:\n", |
865 |
| - " prefix_records = list(itertools.islice(self._records, 0, boundary_idx))\n", |
866 |
| - " prefix_msgs = [r[\"msg\"] for r in prefix_records]\n", |
867 |
| - "\n", |
868 |
| - " user_shadow, assistant_summary = await self._summarize(prefix_msgs)\n", |
869 |
| - "\n", |
870 |
| - " async with self._lock:\n", |
871 |
| - " need_summary_now, boundary_idx_now = self._should_summarize_locked()\n", |
872 |
| - " if not need_summary_now:\n", |
873 |
| - " # normalize anyway if summarization got skipped\n", |
874 |
| - " self._normalize_synthetic_flags_locked()\n", |
875 |
| - " return\n", |
876 |
| - "\n", |
877 |
| - " suffix_records = list(itertools.islice(self._records, boundary_idx_now, None))\n", |
878 |
| - " self._records.clear()\n", |
879 |
| - "\n", |
880 |
| - " # Synthetic summary pair keeps synthetic=True\n", |
881 |
| - " self._records.extend([\n", |
882 |
| - " {\n", |
883 |
| - " \"msg\": {\"role\": \"user\", \"content\": user_shadow},\n", |
884 |
| - " \"meta\": {\n", |
885 |
| - " \"synthetic\": True,\n", |
886 |
| - " \"kind\": \"history_summary_prompt\",\n", |
887 |
| - " \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n", |
888 |
| - " },\n", |
889 |
| - " },\n", |
890 |
| - " {\n", |
891 |
| - " \"msg\": {\"role\": \"assistant\", \"content\": assistant_summary},\n", |
892 |
| - " \"meta\": {\n", |
893 |
| - " \"synthetic\": True,\n", |
894 |
| - " \"kind\": \"history_summary\",\n", |
895 |
| - " \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n", |
896 |
| - " },\n", |
897 |
| - " },\n", |
898 |
| - " ])\n", |
899 |
| - " self._records.extend(suffix_records)\n", |
900 |
| - "\n", |
901 |
| - " # ✅ Ensure all real messages explicitly have synthetic=False\n", |
902 |
| - " self._normalize_synthetic_flags_locked()\n", |
903 |
| - " else:\n", |
904 |
| - " # ✅ Even when we don't summarize, enforce the invariant\n", |
905 |
| - " async with self._lock:\n", |
906 |
| - " self._normalize_synthetic_flags_locked()\n", |
907 |
| - "\n", |
908 |
| - " async def pop_item(self) -> Optional[Dict[str, Any]]:\n", |
909 |
| - " async with self._lock:\n", |
910 |
| - " if not self._records:\n", |
911 |
| - " return None\n", |
912 |
| - " rec = self._records.pop()\n", |
913 |
| - " return dict(rec[\"msg\"]) # model-safe\n", |
914 |
| - "\n", |
915 |
| - " async def clear_session(self) -> None:\n", |
916 |
| - " async with self._lock:\n", |
917 |
| - " self._records.clear()\n", |
918 |
| - "\n", |
919 |
| - " def set_max_turns(self, n: int) -> None:\n", |
920 |
| - " assert n >= 1\n", |
921 |
| - " self.max_turns = n\n", |
922 |
| - "\n", |
923 |
| - " # --------- full-history (for debugging/analytics/observability) ---------\n", |
924 |
| - "\n", |
925 |
| - " # ✅ Backfill safeguard for older records that might lack the flag\n", |
926 |
| - " def _normalize_synthetic_flags_locked(self) -> None:\n", |
927 |
| - " for rec in self._records:\n", |
928 |
| - " role = rec[\"msg\"].get(\"role\")\n", |
929 |
| - " if role in (\"user\", \"assistant\") and \"synthetic\" not in rec[\"meta\"]:\n", |
930 |
| - " rec[\"meta\"][\"synthetic\"] = False\n", |
931 |
| - "\n", |
932 |
| - " \n", |
933 |
| - " async def get_full_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:\n", |
934 |
| - " \"\"\"\n", |
935 |
| - " Returns combined history where each entry is:\n", |
936 |
| - " { \"message\": {role, content[, name]}, \"metadata\": {...} }\n", |
937 |
| - " This is NOT sent to the model; it's for your logs/UI/debugging.\n", |
938 |
| - " \"\"\"\n", |
939 |
| - " async with self._lock:\n", |
940 |
| - " data = list(self._records)\n", |
941 |
| - " out = [{\"message\": dict(rec[\"msg\"]), \"metadata\": dict(rec[\"meta\"])} for rec in data]\n", |
942 |
| - " return out[-limit:] if limit else out\n", |
943 |
| - "\n", |
944 |
| - " # Backwards-compatible alias if you were using this name before\n", |
945 |
| - " async def get_items_with_metadata(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:\n", |
946 |
| - " return await self.get_full_history(limit)\n", |
947 |
| - "\n", |
948 |
| - " # --------- helpers ---------\n", |
949 |
| - "\n", |
950 |
| - " def _split_msg_and_meta(self, it: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n", |
951 |
| - " msg = {k: v for k, v in it.items() if k in self._ALLOWED_MSG_KEYS}\n", |
952 |
| - " extra = {k: v for k, v in it.items() if k not in self._ALLOWED_MSG_KEYS}\n", |
953 |
| - " meta = dict(extra.pop(\"metadata\", {}))\n", |
954 |
| - " meta.update(extra)\n", |
955 |
| - "\n", |
956 |
| - " if \"role\" not in msg or \"content\" not in msg:\n", |
957 |
| - " msg.setdefault(\"role\", \"user\")\n", |
958 |
| - " msg.setdefault(\"content\", str(it))\n", |
959 |
| - "\n", |
960 |
| - " # ✅ Default synthetic flag for real (non-summarized) messages\n", |
961 |
| - " role = msg.get(\"role\")\n", |
962 |
| - " if role in (\"user\", \"assistant\") and \"synthetic\" not in meta:\n", |
963 |
| - " meta[\"synthetic\"] = False\n", |
964 |
| - " return msg, meta\n", |
965 |
| - "\n", |
966 |
| - " def _sanitize_for_model(self, msg: Dict[str, Any]) -> Dict[str, Any]:\n", |
967 |
| - " \"\"\"\n", |
968 |
| - " Strictly keep only allowed keys for model input.\n", |
969 |
| - " \"\"\"\n", |
970 |
| - " return {k: v for k, v in msg.items() if k in self._ALLOWED_MSG_KEYS}\n", |
971 |
| - "\n", |
972 |
| - " def _is_user(self, rec: Dict[str, Dict[str, Any]]) -> bool:\n", |
973 |
| - " return rec[\"msg\"].get(\"role\") == \"user\"\n", |
974 |
| - "\n", |
975 |
| - " def _should_summarize_locked(self) -> Tuple[bool, int]:\n", |
976 |
| - " \"\"\"\n", |
977 |
| - " Find the earliest index among the last `max_turns` user messages.\n", |
978 |
| - " Everything before that index becomes the summarization prefix.\n", |
979 |
| - " \"\"\"\n", |
980 |
| - " idxs = []\n", |
981 |
| - " for i in range(len(self._records) - 1, -1, -1):\n", |
982 |
| - " if self._is_user(self._records[i]):\n", |
983 |
| - " idxs.append(i)\n", |
984 |
| - " if len(idxs) == self.max_turns:\n", |
985 |
| - " break\n", |
986 |
| - " if len(idxs) < self.max_turns:\n", |
987 |
| - " return False, -1\n", |
988 |
| - " boundary = min(idxs)\n", |
989 |
| - " if boundary <= 0:\n", |
990 |
| - " return False, -1\n", |
991 |
| - " return True, boundary\n", |
992 |
| - "\n", |
993 |
| - " async def _summarize(self, prefix_msgs: List[Dict[str, Any]]) -> Tuple[str, str]:\n", |
994 |
| - " \"\"\"\n", |
995 |
| - " Adapter to your summarizer. Provide *model-safe* messages only.\n", |
996 |
| - " \"\"\"\n", |
997 |
| - " if not self.summarizer:\n", |
998 |
| - " # Fallback summary if no summarizer is configured\n", |
999 |
| - " return (\"Summarize the conversation we had so far.\", \"Summary unavailable.\")\n", |
1000 |
| - " # Only send role/content/name to the summarizer as well\n", |
1001 |
| - " clean_prefix = [self._sanitize_for_model(m) for m in prefix_msgs]\n", |
1002 |
| - " return await self.summarizer.summarize(clean_prefix)\n" |
1003 |
| - ] |
1004 |
| - }, |
1005 |
| - { |
1006 |
| - "cell_type": "code", |
1007 |
| - "execution_count": 177, |
1008 |
| - "id": "a3e7cff8", |
1009 |
| - "metadata": {}, |
1010 |
| - "outputs": [], |
1011 |
| - "source": [ |
1012 |
| - "import asyncio\n", |
1013 |
| - "from collections import deque\n", |
1014 |
| - "from typing import Optional\n", |
1015 |
| - "import itertools\n", |
1016 |
| - "\n", |
1017 |
| - "class SummarizingSession:\n", |
1018 |
| - " \"\"\"\n", |
1019 |
| - " Keeps the last N user turns verbatim.\n", |
1020 |
| - " Summarizes everything before that into a synthetic user→assistant pair.\n", |
1021 |
| - " \"\"\"\n", |
1022 |
| - " def __init__(\n", |
1023 |
| - " self,\n", |
1024 |
| - " max_turns: int = 3,\n", |
1025 |
| - " summarizer: Optional[Summarizer] = None,\n", |
1026 |
| - " session_id: Optional[str] = None,\n", |
1027 |
| - " ):\n", |
1028 |
| - " assert max_turns >= 1\n", |
1029 |
| - " self.max_turns = max_turns\n", |
1030 |
| - " self._items: deque[Item] = deque()\n", |
1031 |
| - " self._lock = asyncio.Lock()\n", |
1032 |
| - " self.session_id = session_id or \"default\"\n", |
1033 |
| - " self.summarizer = summarizer\n", |
1034 |
| - "\n", |
1035 |
| - " # ----- public API that mirrors common Session interfaces -----\n", |
1036 |
| - "\n", |
1037 |
| - " async def get_items(self, limit: Optional[int] = None) -> list[Item]:\n", |
1038 |
| - " async with self._lock:\n", |
1039 |
| - " data = list(self._items)\n", |
1040 |
| - " return data[-limit:] if limit else data\n", |
1041 |
| - "\n", |
1042 |
| - " async def add_items(self, items: list[Item]) -> None:\n", |
1043 |
| - " # Append first\n", |
1044 |
| - " async with self._lock:\n", |
1045 |
| - " self._items.extend(items)\n", |
1046 |
| - " need_summary, boundary_idx = self._should_summarize_locked()\n", |
1047 |
| - "\n", |
1048 |
| - " # If we need a summary, **do it without the lock** to avoid blocking others\n", |
1049 |
| - " if need_summary:\n", |
1050 |
| - " # Take a snapshot of the prefix to summarize\n", |
1051 |
| - " async with self._lock:\n", |
1052 |
| - " prefix = list(itertools.islice(self._items, 0, boundary_idx))\n", |
1053 |
| - " # Produce the summary outside the lock\n", |
1054 |
| - " user_shadow, assistant_summary = await self.summarizer.summarize(prefix)\n", |
1055 |
| - "\n", |
1056 |
| - " # Re-acquire and re-check (in case of concurrent updates)\n", |
1057 |
| - " async with self._lock:\n", |
1058 |
| - " need_summary_now, boundary_idx_now = self._should_summarize_locked()\n", |
1059 |
| - " if need_summary_now:\n", |
1060 |
| - " suffix = list(itertools.islice(self._items, boundary_idx_now, None)) \n", |
1061 |
| - " self._items.clear()\n", |
1062 |
| - " self._items.extend([\n", |
1063 |
| - " {\n", |
1064 |
| - " \"role\": \"user\",\n", |
1065 |
| - " \"content\": user_shadow,\n", |
1066 |
| - " \"metadata\": {\n", |
1067 |
| - " \"synthetic\": True,\n", |
1068 |
| - " \"kind\": \"history_summary_prompt\",\n", |
1069 |
| - " \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n", |
1070 |
| - " },\n", |
1071 |
| - " },\n", |
1072 |
| - " {\n", |
1073 |
| - " \"role\": \"assistant\",\n", |
1074 |
| - " \"content\": assistant_summary,\n", |
1075 |
| - " \"metadata\": {\n", |
1076 |
| - " \"synthetic\": True,\n", |
1077 |
| - " \"kind\": \"history_summary\",\n", |
1078 |
| - " \"summary_for_turns\": f\"< all before idx {boundary_idx_now} >\",\n", |
1079 |
| - " },\n", |
1080 |
| - " },\n", |
1081 |
| - " ])\n", |
1082 |
| - " self._items.extend(suffix)\n", |
1083 |
| - " # else: another concurrent writer already summarized; do nothing.\n", |
1084 |
| - "\n", |
1085 |
| - " async def pop_item(self) -> Optional[Item]:\n", |
1086 |
| - " async with self._lock:\n", |
1087 |
| - " return self._items.pop() if self._items else None\n", |
1088 |
| - "\n", |
1089 |
| - " async def clear_session(self) -> None:\n", |
1090 |
| - " async with self._lock:\n", |
1091 |
| - " self._items.clear()\n", |
1092 |
| - "\n", |
1093 |
| - " def set_max_turns(self, n: int) -> None:\n", |
1094 |
| - " assert n >= 1\n", |
1095 |
| - " self.max_turns = n\n", |
1096 |
| - "\n", |
1097 |
| - " # ----- helpers -----\n", |
1098 |
| - "\n", |
1099 |
| - " def _is_user(self, it: Item) -> bool:\n", |
1100 |
| - " return it.get(\"role\") == \"user\"\n", |
1101 |
| - "\n", |
1102 |
| - " def _should_summarize_locked(self) -> tuple[bool, int]:\n", |
1103 |
| - " \"\"\"\n", |
1104 |
| - " Returns (need_summary, boundary_idx).\n", |
1105 |
| - " boundary_idx = earliest index to keep (start of last N user turns).\n", |
1106 |
| - " If False, boundary_idx is undefined.\n", |
1107 |
| - " \"\"\"\n", |
1108 |
| - " idxs = []\n", |
1109 |
| - " for i in range(len(self._items) - 1, -1, -1):\n", |
1110 |
| - " if self._is_user(self._items[i]):\n", |
1111 |
| - " idxs.append(i)\n", |
1112 |
| - " if len(idxs) == self.max_turns:\n", |
1113 |
| - " break\n", |
1114 |
| - " if len(idxs) < self.max_turns:\n", |
1115 |
| - " return False, -1 # not enough user turns yet\n", |
1116 |
| - "\n", |
1117 |
| - " boundary = min(idxs) # earliest of the last N user turns\n", |
1118 |
| - " if boundary <= 0:\n", |
1119 |
| - " return False, -1 # nothing to summarize before boundary\n", |
1120 |
| - " return True, boundary\n" |
1121 |
| - ] |
1122 |
| - }, |
1123 |
| - { |
1124 |
| - "cell_type": "code", |
1125 |
| - "execution_count": 237, |
| 808 | + "execution_count": 250, |
1126 | 809 | "id": "0d8bd4c5",
|
1127 | 810 | "metadata": {},
|
1128 | 811 | "outputs": [],
|
|
0 commit comments