Skip to content

Commit 7556ebc

Browse files
shukladivyansh authored and copybara-github committed
feat: Allow max tokens to be customizable in Claude
PiperOrigin-RevId: 789901925
1 parent 2bb2041 commit 7556ebc

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

src/google/adk/models/anthropic_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@
4646

4747
logger = logging.getLogger("google_adk." + __name__)
4848

49-
MAX_TOKEN = 8192
50-
5149

5250
class ClaudeRequest(BaseModel):
5351
system_instruction: str
@@ -245,9 +243,11 @@ class Claude(BaseLlm):
245243
246244
Attributes:
247245
model: The name of the Claude model.
246+
max_tokens: The maximum number of tokens to generate.
248247
"""
249248

250249
model: str = "claude-3-5-sonnet-v2@20241022"
250+
max_tokens: int = 8192
251251

252252
@staticmethod
253253
@override
@@ -284,7 +284,7 @@ async def generate_content_async(
284284
messages=messages,
285285
tools=tools,
286286
tool_choice=tool_choice,
287-
max_tokens=MAX_TOKEN,
287+
max_tokens=self.max_tokens,
288288
)
289289
yield message_to_generate_content_response(message)
290290

tests/unittests/models/test_anthropic_llm.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,32 @@ async def mock_coro():
122122
assert len(responses) == 1
123123
assert isinstance(responses[0], LlmResponse)
124124
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_generate_content_async_with_max_tokens(
129+
llm_request, generate_content_response, generate_llm_response
130+
):
131+
claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
132+
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
133+
with mock.patch.object(
134+
anthropic_llm,
135+
"message_to_generate_content_response",
136+
return_value=generate_llm_response,
137+
):
138+
# Create a mock coroutine that returns the generate_content_response.
139+
async def mock_coro():
140+
return generate_content_response
141+
142+
# Assign the coroutine to the mocked method
143+
mock_client.messages.create.return_value = mock_coro()
144+
145+
_ = [
146+
resp
147+
async for resp in claude_llm.generate_content_async(
148+
llm_request, stream=False
149+
)
150+
]
151+
mock_client.messages.create.assert_called_once()
152+
_, kwargs = mock_client.messages.create.call_args
153+
assert kwargs["max_tokens"] == 4096

0 commit comments

Comments (0)