-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclaude_stream.py
More file actions
351 lines (308 loc) · 16.8 KB
/
claude_stream.py
File metadata and controls
351 lines (308 loc) · 16.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
import json
import logging
import importlib.util
import uuid
from pathlib import Path
from typing import AsyncGenerator, Optional, Dict, Any, List, Set
import tiktoken
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------------------
try:
    # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002.
    # If tiktoken is unavailable (or fails to fetch the encoding), fall back
    # to None; count_tokens() then reports 0 instead of raising.
    ENCODING = tiktoken.get_encoding("cl100k_base")
except Exception:
    ENCODING = None

# Tags delimiting the model's chain-of-thought inside the text stream.
THINKING_START_TAG = "<thinking>"
THINKING_END_TAG = "</thinking>"
def _pending_tag_suffix(buffer: str, tag: str) -> int:
"""Length of the suffix of buffer that matches the prefix of tag (for partial matches)."""
if not buffer or not tag:
return 0
max_len = min(len(buffer), len(tag) - 1)
for length in range(max_len, 0, -1):
if buffer[-length:] == tag[:length]:
return length
return 0
def count_tokens(text: str) -> int:
    """Count tokens in ``text`` with tiktoken.

    Returns 0 for empty/falsy input, and also 0 when the module-level
    ``ENCODING`` could not be initialised (tiktoken missing) — callers get a
    best-effort usage figure rather than an exception.
    """
    if not text or not ENCODING:
        return 0
    return len(ENCODING.encode(text))
# ------------------------------------------------------------------------------
# Dynamic Loader
# ------------------------------------------------------------------------------
def _load_claude_parser():
    """Dynamically load the sibling ``claude_parser.py`` module.

    Loaded by path (not by package import) so it works regardless of how this
    file itself was imported.

    Returns:
        The executed ``claude_parser`` module object.

    Raises:
        ImportError: if ``claude_parser.py`` cannot be located next to this
            file (clearer than the ``AttributeError`` a ``None`` spec would
            otherwise produce; the caller catches ``Exception`` either way).
    """
    base_dir = Path(__file__).resolve().parent
    spec = importlib.util.spec_from_file_location(
        "v2_claude_parser", str(base_dir / "claude_parser.py")
    )
    if spec is None or spec.loader is None:
        raise ImportError(f"claude_parser.py not found in {base_dir}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
# Bind the SSE event builders from claude_parser; on any failure fall back to
# no-op builders that emit empty strings, so streaming degrades gracefully
# instead of crashing at import time.
try:
    _parser = _load_claude_parser()
    build_message_start = _parser.build_message_start
    build_content_block_start = _parser.build_content_block_start
    build_content_block_delta = _parser.build_content_block_delta
    build_content_block_stop = _parser.build_content_block_stop
    build_ping = _parser.build_ping
    build_message_stop = _parser.build_message_stop
    build_tool_use_start = _parser.build_tool_use_start
    build_tool_use_input_delta = _parser.build_tool_use_input_delta
except Exception as e:
    logger.error(f"Failed to load claude_parser: {e}")

    # Fallback definitions
    def build_message_start(*args, **kwargs): return ""
    def build_content_block_start(*args, **kwargs): return ""
    def build_content_block_delta(*args, **kwargs): return ""
    def build_content_block_stop(*args, **kwargs): return ""
    def build_ping(*args, **kwargs): return ""
    def build_message_stop(*args, **kwargs): return ""
    def build_tool_use_start(*args, **kwargs): return ""
    def build_tool_use_input_delta(*args, **kwargs): return ""
class ClaudeStreamHandler:
    """Translate Amazon Q streaming events into Claude-compatible SSE events.

    Events are fed one at a time to :meth:`handle_event`, which yields
    already-serialized SSE strings produced by the ``claude_parser`` builders.
    The handler tracks content-block numbering, ``<thinking>`` tag parsing,
    tool-use streaming with deduplication, and token accounting for the final
    ``message_stop`` event. :meth:`finish` flushes any leftover state when the
    upstream stream ends without an ``assistantResponseEnd`` event.
    """

    def __init__(self, model: str, input_tokens: int = 0, conversation_id: Optional[str] = None):
        """
        Args:
            model: Model name echoed in the ``message_start`` event.
            input_tokens: Prompt token count reported in usage fields.
            conversation_id: Optional pre-assigned conversation id; the id in
                the first ``initial-response`` payload takes precedence, and a
                fresh UUID is generated if neither is available.
        """
        self.model = model
        self.input_tokens = input_tokens
        self.response_buffer: List[str] = []   # emitted text chunks, for output-token accounting
        self.content_block_index: int = -1     # index of the current Claude content block
        self.content_block_started: bool = False
        self.content_block_start_sent: bool = False
        self.content_block_stop_sent: bool = False
        self.message_start_sent: bool = False
        self.conversation_id: Optional[str] = conversation_id

        # Tool use state
        self.current_tool_use: Optional[Dict[str, Any]] = None
        self.tool_input_buffer: List[str] = []
        self.tool_use_id: Optional[str] = None
        self.tool_name: Optional[str] = None
        self._processed_tool_use_ids: Set[str] = set()  # dedupe repeated toolUseEvents
        self.all_tool_inputs: List[str] = []            # completed tool inputs, for token accounting

        # Think tag state
        self.in_think_block: bool = False
        self.think_buffer: str = ""
        # Remaining chars of a split "<thinking>" tag still to swallow from
        # upcoming chunks.
        self.pending_start_tag_chars: int = 0

        # Response termination flag: set on assistantResponseEnd; later
        # events (and finish()) become no-ops.
        self.response_ended: bool = False

    async def handle_event(self, event_type: str, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """Process a single Amazon Q event and yield Claude SSE events.

        Args:
            event_type: One of ``initial-response``, ``assistantResponseEvent``,
                ``toolUseEvent``, ``assistantResponseEnd``; anything else is
                ignored.
            payload: Parsed event body for that event type.

        Yields:
            Serialized SSE event strings (possibly none for a given event).
        """
        # Early return if response has already ended
        if self.response_ended:
            return

        # 1. Message Start (initial-response)
        if event_type == "initial-response":
            if not self.message_start_sent:
                # Use conversation_id from payload if available, otherwise
                # the one passed to the constructor, else a fresh UUID.
                conv_id = payload.get("conversationId") or self.conversation_id or str(uuid.uuid4())
                self.conversation_id = conv_id
                yield build_message_start(conv_id, self.model, self.input_tokens)
                self.message_start_sent = True
                yield build_ping()

        # 2. Content Block Delta (assistantResponseEvent)
        elif event_type == "assistantResponseEvent":
            content = payload.get("content", "")

            # Close any open tool use block before streaming text again.
            if self.current_tool_use and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True
                self.current_tool_use = None

            # Process content with <thinking> tag detection.
            if content:
                self.think_buffer += content
                while self.think_buffer:
                    # Swallow the tail of a "<thinking>" tag that was split
                    # across chunks (the block was already opened).
                    if self.pending_start_tag_chars > 0:
                        if len(self.think_buffer) < self.pending_start_tag_chars:
                            self.pending_start_tag_chars -= len(self.think_buffer)
                            self.think_buffer = ""
                            break
                        else:
                            self.think_buffer = self.think_buffer[self.pending_start_tag_chars:]
                            self.pending_start_tag_chars = 0
                            if not self.think_buffer:
                                break
                            continue

                    if not self.in_think_block:
                        think_start = self.think_buffer.find(THINKING_START_TAG)
                        if think_start == -1:
                            pending = _pending_tag_suffix(self.think_buffer, THINKING_START_TAG)
                            if pending == len(self.think_buffer) and pending > 0:
                                # The whole buffer is a prefix of "<thinking>":
                                # open a thinking block now and swallow the rest
                                # of the tag from upcoming chunks.
                                # NOTE(review): this is speculative — if the next
                                # chunk does not actually complete the tag, those
                                # swallowed chars are lost; confirm upstream
                                # always sends the tag contiguously or split,
                                # never a bare "<th..." that is plain text.
                                if self.content_block_start_sent:
                                    yield build_content_block_stop(self.content_block_index)
                                    self.content_block_stop_sent = True
                                    self.content_block_start_sent = False
                                self.content_block_index += 1
                                yield build_content_block_start(self.content_block_index, "thinking")
                                self.content_block_start_sent = True
                                self.content_block_started = True
                                self.content_block_stop_sent = False
                                self.in_think_block = True
                                self.pending_start_tag_chars = len(THINKING_START_TAG) - pending
                                self.think_buffer = ""
                                break
                            # Emit everything except a possible partial tag tail.
                            emit_len = len(self.think_buffer) - pending
                            if emit_len <= 0:
                                break
                            text_chunk = self.think_buffer[:emit_len]
                            if text_chunk:
                                if not self.content_block_start_sent:
                                    self.content_block_index += 1
                                    yield build_content_block_start(self.content_block_index, "text")
                                    self.content_block_start_sent = True
                                    self.content_block_started = True
                                    self.content_block_stop_sent = False
                                self.response_buffer.append(text_chunk)
                                yield build_content_block_delta(self.content_block_index, text_chunk)
                            self.think_buffer = self.think_buffer[emit_len:]
                        else:
                            # Complete "<thinking>" found: flush preceding text,
                            # close the text block, open a thinking block.
                            before_text = self.think_buffer[:think_start]
                            if before_text:
                                if not self.content_block_start_sent:
                                    self.content_block_index += 1
                                    yield build_content_block_start(self.content_block_index, "text")
                                    self.content_block_start_sent = True
                                    self.content_block_started = True
                                    self.content_block_stop_sent = False
                                self.response_buffer.append(before_text)
                                yield build_content_block_delta(self.content_block_index, before_text)
                            self.think_buffer = self.think_buffer[think_start + len(THINKING_START_TAG):]
                            if self.content_block_start_sent:
                                yield build_content_block_stop(self.content_block_index)
                                self.content_block_stop_sent = True
                                self.content_block_start_sent = False
                            self.content_block_index += 1
                            yield build_content_block_start(self.content_block_index, "thinking")
                            self.content_block_start_sent = True
                            self.content_block_started = True
                            self.content_block_stop_sent = False
                            self.in_think_block = True
                            self.pending_start_tag_chars = 0
                    else:
                        think_end = self.think_buffer.find(THINKING_END_TAG)
                        if think_end == -1:
                            # No end tag yet: emit all but a possible partial
                            # "</thinking>" tail as thinking deltas.
                            pending = _pending_tag_suffix(self.think_buffer, THINKING_END_TAG)
                            emit_len = len(self.think_buffer) - pending
                            if emit_len <= 0:
                                break
                            thinking_chunk = self.think_buffer[:emit_len]
                            if thinking_chunk:
                                yield build_content_block_delta(
                                    self.content_block_index,
                                    thinking_chunk,
                                    delta_type="thinking_delta",
                                    field_name="thinking"
                                )
                            self.think_buffer = self.think_buffer[emit_len:]
                        else:
                            # End tag found: flush the rest of the thinking
                            # content and close the thinking block.
                            thinking_chunk = self.think_buffer[:think_end]
                            if thinking_chunk:
                                yield build_content_block_delta(
                                    self.content_block_index,
                                    thinking_chunk,
                                    delta_type="thinking_delta",
                                    field_name="thinking"
                                )
                            self.think_buffer = self.think_buffer[think_end + len(THINKING_END_TAG):]
                            yield build_content_block_stop(self.content_block_index)
                            self.content_block_stop_sent = True
                            self.content_block_start_sent = False
                            self.in_think_block = False

        # 3. Tool Use (toolUseEvent)
        elif event_type == "toolUseEvent":
            tool_use_id = payload.get("toolUseId")
            tool_name = payload.get("name")
            tool_input = payload.get("input", {})
            is_stop = payload.get("stop", False)

            # Deduplication: skip if this tool_use_id was already processed and
            # no tool is active (input deltas still pass through while
            # current_tool_use is set).
            if tool_use_id and tool_use_id in self._processed_tool_use_ids and not self.current_tool_use:
                logger.warning(f"Detected duplicate tool use event, toolUseId={tool_use_id}, skipping")
                return

            # Start new tool use
            if tool_use_id and tool_name and not self.current_tool_use:
                # Close previous text block if open
                if self.content_block_start_sent and not self.content_block_stop_sent:
                    yield build_content_block_stop(self.content_block_index)
                    self.content_block_stop_sent = True
                self._processed_tool_use_ids.add(tool_use_id)
                self.content_block_index += 1
                yield build_tool_use_start(self.content_block_index, tool_use_id, tool_name)
                self.content_block_started = True
                self.current_tool_use = {"toolUseId": tool_use_id, "name": tool_name}
                self.tool_use_id = tool_use_id
                self.tool_name = tool_name
                self.tool_input_buffer = []
                self.content_block_stop_sent = False
                self.content_block_start_sent = True

            # Accumulate input (string fragments pass through; dict payloads
            # are re-serialized as JSON).
            if self.current_tool_use and tool_input:
                fragment = ""
                if isinstance(tool_input, str):
                    fragment = tool_input
                else:
                    fragment = json.dumps(tool_input, ensure_ascii=False)
                self.tool_input_buffer.append(fragment)
                yield build_tool_use_input_delta(self.content_block_index, fragment)

            # Stop tool use
            if is_stop and self.current_tool_use:
                full_input = "".join(self.tool_input_buffer)
                self.all_tool_inputs.append(full_input)
                yield build_content_block_stop(self.content_block_index)
                # Reset state to allow the next content block.
                self.content_block_stop_sent = False  # allow next block
                self.content_block_started = False
                self.content_block_start_sent = False  # important: reset start flag for next block
                self.current_tool_use = None
                self.tool_use_id = None
                self.tool_name = None
                self.tool_input_buffer = []

        # 4. Assistant Response End (assistantResponseEnd)
        elif event_type == "assistantResponseEnd":
            # Close any open block
            if self.content_block_started and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True
            # Mark as finished to prevent processing further events
            self.response_ended = True
            # Immediately send message_stop (instead of waiting for finish())
            full_text = "".join(self.response_buffer)
            full_tool_input = "".join(self.all_tool_inputs)
            output_tokens = count_tokens(full_text) + count_tokens(full_tool_input)
            yield build_message_stop(self.input_tokens, output_tokens, "end_turn")

    async def finish(self) -> AsyncGenerator[str, None]:
        """Flush remaining buffered content and send final events.

        No-op when ``assistantResponseEnd`` already produced ``message_stop``.
        Otherwise any leftover ``think_buffer`` content is emitted (as a
        thinking delta if still inside a think block, else as text), the open
        block is closed, and ``message_stop`` with token usage is yielded.
        """
        # Skip if response already ended (message_stop already sent)
        if self.response_ended:
            return

        # Flush any remaining think_buffer content
        if self.think_buffer:
            if self.in_think_block:
                # Emit remaining thinking content
                yield build_content_block_delta(
                    self.content_block_index,
                    self.think_buffer,
                    delta_type="thinking_delta",
                    field_name="thinking"
                )
            else:
                # Emit remaining text content (may include a partial tag tail
                # that never completed).
                if not self.content_block_start_sent:
                    self.content_block_index += 1
                    yield build_content_block_start(self.content_block_index, "text")
                    self.content_block_start_sent = True
                    self.content_block_started = True
                    self.content_block_stop_sent = False
                self.response_buffer.append(self.think_buffer)
                yield build_content_block_delta(self.content_block_index, self.think_buffer)
            self.think_buffer = ""

        # Ensure last block is closed
        if self.content_block_started and not self.content_block_stop_sent:
            yield build_content_block_stop(self.content_block_index)
            self.content_block_stop_sent = True

        # Calculate output tokens from all emitted text and tool input.
        full_text = "".join(self.response_buffer)
        full_tool_input = "".join(self.all_tool_inputs)
        output_tokens = count_tokens(full_text) + count_tokens(full_tool_input)
        yield build_message_stop(self.input_tokens, output_tokens, "end_turn")