Skip to content

Commit c2c2674

Browse files
authored
enable session affinity for cache optimization (#89)
* enable session affinity for cache optimization * fixing CI * fix bug on logger.warning
1 parent 5f3226b commit c2c2674

File tree

3 files changed

+80
-88
lines changed

3 files changed

+80
-88
lines changed

stagehand/api.py

Lines changed: 71 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import json
22
from typing import Any
33

4-
import httpx
5-
64
from .utils import convert_dict_keys_to_camel_case
75

86
__all__ = ["_create_session", "_execute"]
@@ -73,21 +71,20 @@ async def _create_session(self):
7371
"x-language": "python",
7472
}
7573

76-
client = httpx.AsyncClient(timeout=self.timeout_settings)
77-
async with client:
78-
resp = await client.post(
79-
f"{self.api_url}/sessions/start",
80-
json=payload,
81-
headers=headers,
82-
)
83-
if resp.status_code != 200:
84-
raise RuntimeError(f"Failed to create session: {resp.text}")
85-
data = resp.json()
86-
self.logger.debug(f"Session created: {data}")
87-
if not data.get("success") or "sessionId" not in data.get("data", {}):
88-
raise RuntimeError(f"Invalid response format: {resp.text}")
74+
# async with self._client:
75+
resp = await self._client.post(
76+
f"{self.api_url}/sessions/start",
77+
json=payload,
78+
headers=headers,
79+
)
80+
if resp.status_code != 200:
81+
raise RuntimeError(f"Failed to create session: {resp.text}")
82+
data = resp.json()
83+
self.logger.debug(f"Session created: {data}")
84+
if not data.get("success") or "sessionId" not in data.get("data", {}):
85+
raise RuntimeError(f"Invalid response format: {resp.text}")
8986

90-
self.session_id = data["data"]["sessionId"]
87+
self.session_id = data["data"]["sessionId"]
9188

9289

9390
async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
@@ -109,65 +106,61 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
109106
# Convert snake_case keys to camelCase for the API
110107
modified_payload = convert_dict_keys_to_camel_case(payload)
111108

112-
client = httpx.AsyncClient(timeout=self.timeout_settings)
113-
114-
async with client:
115-
try:
116-
# Always use streaming for consistent log handling
117-
async with client.stream(
118-
"POST",
119-
f"{self.api_url}/sessions/{self.session_id}/{method}",
120-
json=modified_payload,
121-
headers=headers,
122-
) as response:
123-
if response.status_code != 200:
124-
error_text = await response.aread()
125-
error_message = error_text.decode("utf-8")
126-
self.logger.error(
127-
f"[HTTP ERROR] Status {response.status_code}: {error_message}"
128-
)
129-
raise RuntimeError(
130-
f"Request failed with status {response.status_code}: {error_message}"
131-
)
132-
result = None
133-
134-
async for line in response.aiter_lines():
135-
# Skip empty lines
136-
if not line.strip():
137-
continue
138-
139-
try:
140-
# Handle SSE-style messages that start with "data: "
141-
if line.startswith("data: "):
142-
line = line[len("data: ") :]
143-
144-
message = json.loads(line)
145-
# Handle different message types
146-
msg_type = message.get("type")
147-
148-
if msg_type == "system":
149-
status = message.get("data", {}).get("status")
150-
if status == "error":
151-
error_msg = message.get("data", {}).get(
152-
"error", "Unknown error"
153-
)
154-
self.logger.error(f"[ERROR] {error_msg}")
155-
raise RuntimeError(
156-
f"Server returned error: {error_msg}"
157-
)
158-
elif status == "finished":
159-
result = message.get("data", {}).get("result")
160-
elif msg_type == "log":
161-
# Process log message using _handle_log
162-
await self._handle_log(message)
163-
else:
164-
# Log any other message types
165-
self.logger.debug(f"[UNKNOWN] Message type: {msg_type}")
166-
except json.JSONDecodeError:
167-
self.logger.warning(f"Could not parse line as JSON: {line}")
168-
169-
# Return the final result
170-
return result
171-
except Exception as e:
172-
self.logger.error(f"[EXCEPTION] {str(e)}")
173-
raise
109+
# async with self._client:
110+
try:
111+
# Always use streaming for consistent log handling
112+
async with self._client.stream(
113+
"POST",
114+
f"{self.api_url}/sessions/{self.session_id}/{method}",
115+
json=modified_payload,
116+
headers=headers,
117+
) as response:
118+
if response.status_code != 200:
119+
error_text = await response.aread()
120+
error_message = error_text.decode("utf-8")
121+
self.logger.error(
122+
f"[HTTP ERROR] Status {response.status_code}: {error_message}"
123+
)
124+
raise RuntimeError(
125+
f"Request failed with status {response.status_code}: {error_message}"
126+
)
127+
result = None
128+
129+
async for line in response.aiter_lines():
130+
# Skip empty lines
131+
if not line.strip():
132+
continue
133+
134+
try:
135+
# Handle SSE-style messages that start with "data: "
136+
if line.startswith("data: "):
137+
line = line[len("data: ") :]
138+
139+
message = json.loads(line)
140+
# Handle different message types
141+
msg_type = message.get("type")
142+
143+
if msg_type == "system":
144+
status = message.get("data", {}).get("status")
145+
if status == "error":
146+
error_msg = message.get("data", {}).get(
147+
"error", "Unknown error"
148+
)
149+
self.logger.error(f"[ERROR] {error_msg}")
150+
raise RuntimeError(f"Server returned error: {error_msg}")
151+
elif status == "finished":
152+
result = message.get("data", {}).get("result")
153+
elif msg_type == "log":
154+
# Process log message using _handle_log
155+
await self._handle_log(message)
156+
else:
157+
# Log any other message types
158+
self.logger.debug(f"[UNKNOWN] Message type: {msg_type}")
159+
except json.JSONDecodeError:
160+
self.logger.error(f"Could not parse line as JSON: {line}")
161+
162+
# Return the final result
163+
return result
164+
except Exception as e:
165+
self.logger.error(f"[EXCEPTION] {str(e)}")
166+
raise

stagehand/main.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144
)
145145
if not self.model_api_key:
146146
# Model API key needed if Stagehand server creates the session
147-
self.logger.warning(
147+
self.logger.info(
148148
"model_api_key is recommended when creating a new BROWSERBASE session to configure the Stagehand server's LLM."
149149
)
150150
elif self.session_id:
@@ -161,9 +161,7 @@ def __init__(
161161
# Register signal handlers for graceful shutdown
162162
self._register_signal_handlers()
163163

164-
self._client: Optional[httpx.AsyncClient] = (
165-
None # Used for server communication in BROWSERBASE
166-
)
164+
self._client = httpx.AsyncClient(timeout=self.timeout_settings)
167165

168166
self._playwright: Optional[Playwright] = None
169167
self._browser = None
@@ -388,9 +386,6 @@ async def init(self):
388386
self._playwright = await async_playwright().start()
389387

390388
if self.env == "BROWSERBASE":
391-
if not self._client:
392-
self._client = httpx.AsyncClient(timeout=self.timeout_settings)
393-
394389
# Create session if we don't have one
395390
if not self.session_id:
396391
await self._create_session() # Uses self._client and api_url

tests/unit/llm/test_llm_integration.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from stagehand.llm.client import LLMClient
88
from tests.mocks.mock_llm import MockLLMClient, MockLLMResponse
9+
from stagehand.logging import StagehandLogger
910

1011

1112
class TestLLMClientInitialization:
@@ -15,7 +16,8 @@ def test_llm_client_creation_with_openai(self):
1516
"""Test LLM client creation with OpenAI provider"""
1617
client = LLMClient(
1718
api_key="test-openai-key",
18-
default_model="gpt-4o"
19+
default_model="gpt-4o",
20+
stagehand_logger=StagehandLogger(),
1921
)
2022

2123
assert client.default_model == "gpt-4o"
@@ -25,7 +27,8 @@ def test_llm_client_creation_with_anthropic(self):
2527
"""Test LLM client creation with Anthropic provider"""
2628
client = LLMClient(
2729
api_key="test-anthropic-key",
28-
default_model="claude-3-sonnet"
30+
default_model="claude-3-sonnet",
31+
stagehand_logger=StagehandLogger(),
2932
)
3033

3134
assert client.default_model == "claude-3-sonnet"
@@ -35,7 +38,8 @@ def test_llm_client_with_custom_options(self):
3538
"""Test LLM client with custom configuration options"""
3639
client = LLMClient(
3740
api_key="test-key",
38-
default_model="gpt-4o-mini"
41+
default_model="gpt-4o-mini",
42+
stagehand_logger=StagehandLogger(),
3943
)
4044

4145
assert client.default_model == "gpt-4o-mini"

0 commit comments

Comments
 (0)