From f550bae482473bc9b4051edc419597a712518ca0 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 21 Nov 2024 10:25:36 -0600 Subject: [PATCH 1/3] Rename .turns() -> .get_turns(); .last_turn() -> .get_last_turn() --- chatlas/_chat.py | 24 +++++++++++++++--------- chatlas/_turn.py | 4 ++-- tests/conftest.py | 14 +++++++------- tests/test_chat.py | 16 ++++++++-------- tests/test_provider_anthropic.py | 2 +- tests/test_provider_azure.py | 4 ++-- tests/test_provider_bedrock.py | 2 +- tests/test_provider_google.py | 2 +- tests/test_provider_openai.py | 4 ++-- 9 files changed, 39 insertions(+), 33 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 0ed49bde..04559ed2 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -81,7 +81,7 @@ def __init__( self._turns: list[Turn] = list(turns or []) self.tools: dict[str, Tool] = {} - def turns( + def get_turns( self, *, include_system_prompt: bool = False, @@ -102,7 +102,7 @@ def turns( return self._turns[1:] return self._turns - def last_turn( + def get_last_turn( self, *, role: Literal["assistant", "user", "system"] = "assistant", @@ -145,7 +145,12 @@ def set_turns(self, turns: Sequence[Turn]): @property def system_prompt(self) -> str | None: """ - Get the system prompt for the chat. + A property to get (or set) the system prompt for the chat. + + Returns + ------- + str | None + The system prompt (if any). """ if self._turns and self._turns[0].role == "system": return self._turns[0].text @@ -211,7 +216,8 @@ def server(input): # noqa: A002 chat = ui.Chat( "chat", messages=[ - {"role": turn.role, "content": turn.text} for turn in self.turns() + {"role": turn.role, "content": turn.text} + for turn in self.get_turns() ], ) @@ -465,7 +471,7 @@ def extract_data( for _ in response: pass - turn = self.last_turn() + turn = self.get_last_turn() assert turn is not None res: list[ContentJson] = [] @@ -519,7 +525,7 @@ async def extract_data_async( async for _ in response: pass - turn = self.last_turn() + turn = self.get_last_turn() assert turn is not None res: list[ContentJson] = [] @@ -782,7 +788,7 @@ async def _submit_turns_async( self._turns.extend([user_turn, turn]) def _invoke_tools(self) -> Turn | None: - turn = self.last_turn() + turn = self.get_last_turn() if turn is None: return None @@ -799,7 +805,7 @@ def _invoke_tools(self) -> Turn | None: return Turn("user", results) async def _invoke_tools_async(self) -> Turn | None: - turn = self.last_turn() + turn = self.get_last_turn() if turn is None: return None @@ -854,7 +860,7 @@ async def _invoke_tool_async( return ContentToolResult(id_, None, str(e)) def __str__(self): - turns = self.turns(include_system_prompt=True) + turns = self.get_turns(include_system_prompt=True) tokens = sum(sum(turn.tokens) for turn in turns) output = f"\n" for turn in turns: diff --git a/chatlas/_turn.py b/chatlas/_turn.py index e2b0f8dc..a6883727 100644 --- a/chatlas/_turn.py +++ b/chatlas/_turn.py @@ -31,7 +31,7 @@ class Turn: chat = ChatOpenAI() str(chat.chat("What is the capital of France?")) - turns = chat.turns() + turns = chat.get_turns() assert len(turns) == 2 assert isinstance(turns[0], Turn) assert turns[0].role == "user" @@ -39,7 +39,7 @@ class Turn: # Load context into a new chat instance chat2 = ChatAnthropic(turns=turns) - turns2 = chat2.turns() + turns2 = chat2.get_turns() assert turns == turns2 ``` diff --git a/tests/conftest.py b/tests/conftest.py index 15d5f268..8e2d7b4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,13 +50,13 @@ def assert_turns_system(chat_fun: ChatFun): chat = chat_fun(system_prompt=system_prompt) response = chat.chat("What is the name of Winnie the Pooh's human friend?") response_text = str(response) - assert len(chat.turns()) == 2 + assert len(chat.get_turns()) == 2 assert "CHRISTOPHER ROBIN" in response_text chat = chat_fun(turns=[Turn("system", system_prompt)]) response = chat.chat("What is the name of Winnie the Pooh's human friend?") assert "CHRISTOPHER ROBIN" in str(response) - assert len(chat.turns()) == 2 + assert len(chat.get_turns()) == 2 def assert_turns_existing(chat_fun: ChatFun): @@ -70,11 +70,11 @@ def assert_turns_existing(chat_fun: ChatFun): ), ] ) - assert len(chat.turns()) == 2 + assert len(chat.get_turns()) == 2 response = chat.chat("Who is the remaining one? Just give the name") assert "Prancer" in str(response) - assert len(chat.turns()) == 4 + assert len(chat.get_turns()) == 4 def assert_tools_simple(chat_fun: ChatFun, stream: bool = True): @@ -133,7 +133,7 @@ def favorite_color(person: str): assert "Joe: sage green" in str(response) assert "Hadley: red" in str(response) - assert len(chat.turns()) == 4 + assert len(chat.get_turns()) == 4 def assert_tools_sequential(chat_fun: ChatFun, total_calls: int, stream: bool = True): @@ -161,7 +161,7 @@ def popular_name(year: int): stream=stream, ) assert "Susan" in str(response) - assert len(chat.turns()) == total_calls + assert len(chat.get_turns()) == total_calls def assert_data_extraction(chat_fun: ChatFun): @@ -205,4 +205,4 @@ def assert_images_remote_error(chat_fun: ChatFun): with pytest.raises(Exception, match="Remote images aren't supported"): chat.chat("What's in this image?", image_remote) - assert len(chat.turns()) == 0 + assert len(chat.get_turns()) == 0 diff --git a/tests/test_chat.py b/tests/test_chat.py index b1bff2a2..6003ee4d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -31,7 +31,7 @@ def test_simple_streaming_chat(): result = "".join(chunks) rainbow_re = "^red *\norange *\nyellow *\ngreen *\nblue *\nindigo *\nviolet *\n?$" assert re.match(rainbow_re, result.lower()) - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert re.match(rainbow_re, turn.text.lower()) @@ -48,7 +48,7 @@ async def test_simple_streaming_chat_async(): result = "".join(chunks) rainbow_re = "^red *\norange *\nyellow *\ngreen *\nblue *\nindigo *\nviolet *\n?$" assert re.match(rainbow_re, result.lower()) - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert re.match(rainbow_re, turn.text.lower()) @@ -94,24 +94,24 @@ class Person(BaseModel): def test_last_turn_retrieval(): chat = ChatOpenAI() - assert chat.last_turn(role="user") is None - assert chat.last_turn(role="assistant") is None + assert chat.get_last_turn(role="user") is None + assert chat.get_last_turn(role="assistant") is None chat.chat("Hi") - user_turn = chat.last_turn(role="user") + user_turn = chat.get_last_turn(role="user") assert user_turn is not None and user_turn.role == "user" - turn = chat.last_turn(role="assistant") + turn = chat.get_last_turn(role="assistant") assert turn is not None and turn.role == "assistant" def test_system_prompt_retrieval(): chat1 = ChatOpenAI() assert chat1.system_prompt is None - assert chat1.last_turn(role="system") is None + assert chat1.get_last_turn(role="system") is None chat2 = ChatOpenAI(system_prompt="You are from New Zealand") assert chat2.system_prompt == "You are from New Zealand" - turn = chat2.last_turn(role="system") + turn = chat2.get_last_turn(role="system") assert turn is not None and turn.text == "You are from New Zealand" diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py index feb709fb..d6e847dd 100644 --- a/tests/test_provider_anthropic.py +++ b/tests/test_provider_anthropic.py @@ -21,7 +21,7 @@ def test_anthropic_simple_request(): system_prompt="Be as terse as possible; no punctuation", ) chat.chat("What is 1 + 1?") - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (26, 5) diff --git a/tests/test_provider_azure.py b/tests/test_provider_azure.py index d33284ef..045ffb35 100644 --- a/tests/test_provider_azure.py +++ b/tests/test_provider_azure.py @@ -18,7 +18,7 @@ def test_azure_simple_request(): response = chat.chat("What is 1 + 1?") assert "2" == response.get_content() - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (27, 1) @@ -34,6 +34,6 @@ async def test_azure_simple_request_async(): response = await chat.chat_async("What is 1 + 1?") assert "2" == await response.get_content() - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (27, 1) diff --git a/tests/test_provider_bedrock.py b/tests/test_provider_bedrock.py index 30e03dbb..0f257d6a 100644 --- a/tests/test_provider_bedrock.py +++ b/tests/test_provider_bedrock.py @@ -22,7 +22,7 @@ # system_prompt="Be as terse as possible; no punctuation", # ) # _ = str(chat.chat("What is 1 + 1?")) -# turn = chat.last_turn() +# turn = chat.get_last_turn() # assert turn is not None # assert turn.tokens == (26, 5) diff --git a/tests/test_provider_google.py b/tests/test_provider_google.py index 4b1c765d..f232b600 100644 --- a/tests/test_provider_google.py +++ b/tests/test_provider_google.py @@ -25,7 +25,7 @@ def test_google_simple_request(): system_prompt="Be as terse as possible; no punctuation", ) chat.chat("What is 1 + 1?") - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (17, 2) diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py index 73cba99e..314dd6dc 100644 --- a/tests/test_provider_openai.py +++ b/tests/test_provider_openai.py @@ -19,7 +19,7 @@ def test_openai_simple_request(): system_prompt="Be as terse as possible; no punctuation", ) chat.chat("What is 1 + 1?") - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (27, 1) @@ -71,7 +71,7 @@ async def test_openai_logprobs(): async for x in await chat.stream_async("Hi", kwargs={"logprobs": True}): pieces.append(x) - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None logprobs = turn.json["choices"][0]["logprobs"]["content"] assert len(logprobs) == len(pieces) From efd39a3e2d7a4d10acee1e12d9244944167e8e8d Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 10 Dec 2024 17:46:15 -0600 Subject: [PATCH 2/3] Missed some renaming --- README.md | 2 +- docs/reference/Chat.qmd | 4 ++-- tests/test_provider_anthropic.py | 2 +- tests/test_provider_google.py | 3 +-- tests/test_provider_openai.py | 2 +- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 2df6c43f..122716e0 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ chat.chat("What is the capital of France?", echo="all") This shows important information like tool call results, finish reasons, and more. -If the problem isn't self-evident, you can also reach into the `.last_turn()`, which contains the full response object, with full details about the completion. +If the problem isn't self-evident, you can also reach into the `.get_last_turn()`, which contains the full response object, with full details about the completion.
diff --git a/docs/reference/Chat.qmd b/docs/reference/Chat.qmd index e7c4eea2..013826ed 100644 --- a/docs/reference/Chat.qmd +++ b/docs/reference/Chat.qmd @@ -32,7 +32,7 @@ You should generally not create this object yourself, but instead call | [console](#chatlas.Chat.console) | Enter a chat console to interact with the LLM. | | [extract_data](#chatlas.Chat.extract_data) | Extract structured data from the given input. | | [extract_data_async](#chatlas.Chat.extract_data_async) | Extract structured data from the given input asynchronously. | -| [last_turn](#chatlas.Chat.last_turn) | Get the last turn in the chat with a specific role. | +| [get_last_turn](#chatlas.Chat.get_last_turn) | Get the last turn in the chat with a specific role. | | [register_tool](#chatlas.Chat.register_tool) | Register a tool (function) with the chat. | | [set_turns](#chatlas.Chat.set_turns) | Set the turns of the chat. | | [tokens](#chatlas.Chat.tokens) | Get the tokens for each turn in the chat. | @@ -158,7 +158,7 @@ Extract structured data from the given input asynchronously. |--------|-----------------------------------------------------|---------------------| | | [dict](`dict`)\[[str](`str`), [Any](`typing.Any`)\] | The extracted data. | -### last_turn { #chatlas.Chat.last_turn } +### get_last_turn { #chatlas.Chat.get_last_turn } ```python Chat.get_last_turn(role='assistant') diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py index 2d5a4ac9..1a2a24b1 100644 --- a/tests/test_provider_anthropic.py +++ b/tests/test_provider_anthropic.py @@ -37,7 +37,7 @@ async def test_anthropic_simple_streaming_request(): async for x in foo: res.append(x) assert "2" in "".join(res) - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.finish_reason == "end_turn" diff --git a/tests/test_provider_google.py b/tests/test_provider_google.py index eb7718d4..dbf9bc53 100644 --- a/tests/test_provider_google.py +++ b/tests/test_provider_google.py @@ -2,7 +2,6 @@ import time import pytest - from chatlas import ChatGoogle from .conftest import ( @@ -41,7 +40,7 @@ async def test_google_simple_streaming_request(): async for x in await chat.stream_async("What is 1 + 1?"): res.append(x) assert "2" in "".join(res) - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.finish_reason == "STOP" diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py index a9214743..05c16a94 100644 --- a/tests/test_provider_openai.py +++ b/tests/test_provider_openai.py @@ -34,7 +34,7 @@ async def test_openai_simple_streaming_request(): async for x in await chat.stream_async("What is 1 + 1?"): res.append(x) assert "2" in "".join(res) - turn = chat.last_turn() + turn = chat.get_last_turn() assert turn is not None assert turn.finish_reason == "stop" From 0287c60a9f4c53b984839c27814fd12288fac62c Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 10 Dec 2024 17:55:41 -0600 Subject: [PATCH 3/3] Eliminate warnings from tests --- tests/conftest.py | 2 +- tests/test_content_image.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c031a29a..313d916a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -178,7 +178,7 @@ def assert_images_inline(chat_fun: ChatFun, stream: bool = True): chat = chat_fun() response = chat.chat( "What's in this image?", - content_image_file(str(img_path)), + content_image_file(str(img_path), resize="low"), stream=stream, ) assert "red" in str(response).lower() diff --git a/tests/test_content_image.py b/tests/test_content_image.py index f245214e..484cdfdf 100644 --- a/tests/test_content_image.py +++ b/tests/test_content_image.py @@ -37,7 +37,7 @@ def test_can_create_image_from_path(tmp_path): path = tmp_path / "test.png" img.save(path) - obj = content_image_file(str(path)) + obj = content_image_file(str(path), resize="low") assert isinstance(obj, ContentImageInline) @@ -65,7 +65,8 @@ def test_image_resizing(tmp_path): content_image_file(str(tmp_path / "test.txt")) # Test valid resize options - assert content_image_file(str(img_path)) is not None + with pytest.warns(RuntimeWarning): + assert content_image_file(str(img_path)) is not None assert content_image_file(str(img_path), resize="low") is not None assert content_image_file(str(img_path), resize="high") is not None assert content_image_file(str(img_path), resize="none") is not None