Skip to content

Commit 79d7809

Browse files
Support prompt input streaming
1 parent a4e14f1 commit 79d7809

File tree

4 files changed

+501
-5
lines changed

4 files changed

+501
-5
lines changed

pkg/claude/claude.go

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ const (
2929
StreamJSONOutput OutputFormat = "stream-json"
3030
)
3131

32+
// InputFormat defines the input format for Claude Code requests
33+
type InputFormat string
34+
35+
const (
36+
// TextInput sends plain text input (default)
37+
TextInput InputFormat = "text"
38+
// StreamJSONInput sends streaming JSON input for multiple prompts
39+
StreamJSONInput InputFormat = "stream-json"
40+
)
41+
3242
// ClaudeClient is the main client for interacting with Claude Code
3343
type ClaudeClient struct {
3444
// BinPath is the path to the Claude Code binary
@@ -41,6 +51,8 @@ type ClaudeClient struct {
4151
type RunOptions struct {
4252
// Format specifies the output format (text, json, stream-json)
4353
Format OutputFormat
54+
// InputFormat specifies the input format (text, stream-json)
55+
InputFormat InputFormat
4456
// SystemPrompt overrides the default system prompt
4557
SystemPrompt string
4658
// AppendPrompt appends to the default system prompt
@@ -126,6 +138,18 @@ type QueryOptions struct {
126138
BufferConfig *buffer.Config `json:"-"`
127139
}
128140

141+
// StreamInputMessage represents a message for streaming input to Claude Code
142+
type StreamInputMessage struct {
143+
Type string `json:"type"`
144+
Message StreamInputContent `json:"message"`
145+
}
146+
147+
// StreamInputContent represents the content of a streaming input message
148+
type StreamInputContent struct {
149+
Role string `json:"role"`
150+
Content string `json:"content"`
151+
}
152+
129153
// Message represents a message from Claude Code in streaming mode
130154
// Aligned with Python SDK message structure
131155
type Message struct {
@@ -321,8 +345,23 @@ func (c *ClaudeClient) RunPromptCtx(ctx context.Context, prompt string, opts *Ru
321345
}, nil
322346
}
323347

324-
// StreamPrompt executes a prompt with Claude Code and streams the results through a channel
348+
// StreamPrompt executes a prompt with Claude Code and streams the results thro ugh a channel
325349
func (c *ClaudeClient) StreamPrompt(ctx context.Context, prompt string, opts *RunOptions) (<-chan Message, <-chan error) {
350+
// Create a channel for a single prompt
351+
promptCh := make(chan string, 1)
352+
promptCh <- prompt
353+
close(promptCh)
354+
355+
// Use the multi-prompt streaming function
356+
return c.StreamPromptsToSession(ctx, promptCh, opts)
357+
}
358+
359+
// StreamPromptsToSession starts a Claude Code session and streams prompts to it continuously
360+
func (c *ClaudeClient) StreamPromptsToSession(
361+
ctx context.Context,
362+
promptCh <-chan string,
363+
opts *RunOptions,
364+
) (<-chan Message, <-chan error) {
326365
messageCh := make(chan Message)
327366
errCh := make(chan error, 1)
328367

@@ -333,11 +372,13 @@ func (c *ClaudeClient) StreamPrompt(ctx context.Context, prompt string, opts *Ru
333372
// Force stream-json format for streaming
334373
streamOpts := *opts
335374
streamOpts.Format = StreamJSONOutput
375+
streamOpts.InputFormat = StreamJSONInput
336376

337377
// Claude CLI requires --verbose when using --output-format=stream-json with --print
338378
streamOpts.Verbose = true
339379

340-
args := BuildArgs(prompt, &streamOpts)
380+
// Remove the initial prompt since we'll stream through stdin
381+
args := BuildArgs("", &streamOpts)
341382

342383
go func() {
343384
defer close(messageCh)
@@ -346,6 +387,12 @@ func (c *ClaudeClient) StreamPrompt(ctx context.Context, prompt string, opts *Ru
346387
// Create a custom command that supports context
347388
cmd := execCommand(ctx, c.BinPath, args...)
348389

390+
stdin, err := cmd.StdinPipe()
391+
if err != nil {
392+
errCh <- fmt.Errorf("failed to get stdin pipe: %w", err)
393+
return
394+
}
395+
349396
stdout, err := cmd.StdoutPipe()
350397
if err != nil {
351398
errCh <- fmt.Errorf("failed to get stdout pipe: %w", err)
@@ -376,6 +423,67 @@ func (c *ClaudeClient) StreamPrompt(ctx context.Context, prompt string, opts *Ru
376423
return
377424
}
378425

426+
// Channel to signal when command is ready to receive input
427+
cmdReady := make(chan struct{})
428+
429+
// Start a goroutine to handle input prompts
430+
go func() {
431+
defer stdin.Close()
432+
433+
// Wait for command to be ready before processing prompts
434+
select {
435+
case <-cmdReady:
436+
// Command is ready, proceed with prompt processing
437+
case <-ctx.Done():
438+
return
439+
}
440+
441+
for {
442+
select {
443+
case prompt, ok := <-promptCh:
444+
if !ok {
445+
// Prompt channel closed, close stdin
446+
return
447+
}
448+
449+
// Create JSON message for streaming input
450+
streamMsg := StreamInputMessage{
451+
Type: "user",
452+
Message: StreamInputContent{
453+
Role: "user",
454+
Content: prompt,
455+
},
456+
}
457+
458+
// Encode as JSON and send
459+
jsonData, err := json.Marshal(streamMsg)
460+
if err != nil {
461+
errCh <- fmt.Errorf("failed to marshal stream input message: %w", err)
462+
return
463+
}
464+
465+
if _, err := fmt.Fprintln(stdin, string(jsonData)); err != nil {
466+
errCh <- fmt.Errorf("failed to write JSON prompt to stdin: %w", err)
467+
return
468+
}
469+
case <-ctx.Done():
470+
return
471+
}
472+
}
473+
}()
474+
475+
// Give the command a moment to initialize before signaling ready
476+
// This prevents the race condition where we write to stdin before
477+
// the Claude process is ready to read from it
478+
go func() {
479+
select {
480+
case <-time.After(100 * time.Millisecond):
481+
close(cmdReady)
482+
case <-ctx.Done():
483+
return
484+
}
485+
}()
486+
379487
// Use buffered reader with configurable buffer size instead of scanner
380488
reader := bufio.NewReaderSize(stdout, int(bufferConfig.MaxStdoutSize/1000)) // Use reasonable buffer size
381489

@@ -642,6 +750,10 @@ func BuildArgs(prompt string, opts *RunOptions) []string {
642750
args = append(args, "--output-format", string(opts.Format))
643751
}
644752

753+
if opts.InputFormat != "" {
754+
args = append(args, "--input-format", string(opts.InputFormat))
755+
}
756+
645757
if opts.SystemPrompt != "" {
646758
args = append(args, "--system-prompt", opts.SystemPrompt)
647759
}

pkg/claude/input_format_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package claude
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestInputFormat_Constants(t *testing.T) {
9+
tests := []struct {
10+
format InputFormat
11+
expected string
12+
}{
13+
{TextInput, "text"},
14+
{StreamJSONInput, "stream-json"},
15+
}
16+
17+
for _, test := range tests {
18+
if string(test.format) != test.expected {
19+
t.Errorf("Expected %s, got %s", test.expected, string(test.format))
20+
}
21+
}
22+
}
23+
24+
func TestBuildArgs_InputFormat(t *testing.T) {
25+
tests := []struct {
26+
name string
27+
opts *RunOptions
28+
expectedArg string
29+
}{
30+
{
31+
name: "No input format specified",
32+
opts: &RunOptions{},
33+
expectedArg: "",
34+
},
35+
{
36+
name: "Text input format",
37+
opts: &RunOptions{InputFormat: TextInput},
38+
expectedArg: "--input-format text",
39+
},
40+
{
41+
name: "Stream JSON input format",
42+
opts: &RunOptions{InputFormat: StreamJSONInput},
43+
expectedArg: "--input-format stream-json",
44+
},
45+
}
46+
47+
for _, test := range tests {
48+
t.Run(test.name, func(t *testing.T) {
49+
args := BuildArgs("test prompt", test.opts)
50+
argsStr := strings.Join(args, " ")
51+
52+
if test.expectedArg == "" {
53+
if strings.Contains(argsStr, "--input-format") {
54+
t.Errorf("Expected no --input-format flag, but found one in: %s", argsStr)
55+
}
56+
} else {
57+
if !strings.Contains(argsStr, test.expectedArg) {
58+
t.Errorf("Expected to find '%s' in args: %s", test.expectedArg, argsStr)
59+
}
60+
}
61+
})
62+
}
63+
}
64+
65+
func TestStreamInputMessage_JSONMarshaling(t *testing.T) {
66+
msg := StreamInputMessage{
67+
Type: "user",
68+
Message: StreamInputContent{
69+
Role: "user",
70+
Content: "Test message",
71+
},
72+
}
73+
74+
// Test the struct fields are correctly defined
75+
if msg.Type != "user" {
76+
t.Errorf("Expected Type 'user', got %s", msg.Type)
77+
}
78+
if msg.Message.Role != "user" {
79+
t.Errorf("Expected Role 'user', got %s", msg.Message.Role)
80+
}
81+
if msg.Message.Content != "Test message" {
82+
t.Errorf("Expected Content 'Test message', got %s", msg.Message.Content)
83+
}
84+
}

pkg/claude/query_test.go

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package claude
33
import (
44
"context"
55
"encoding/json"
6+
"os"
7+
"os/exec"
8+
"path/filepath"
69
"testing"
710
"time"
811

@@ -13,8 +16,9 @@ func TestQuery_PythonSDKAlignment(t *testing.T) {
1316
if testing.Short() {
1417
t.Skip("Skipping integration test in short mode")
1518
}
16-
17-
client := &ClaudeClient{BinPath: "echo"}
19+
20+
skipIfNoClaudeCLI(t)
21+
client := newTestClient(t)
1822
ctx := context.Background()
1923

2024
// Test basic Query method with QueryOptions
@@ -50,7 +54,8 @@ func TestQuerySync_PythonSDKAlignment(t *testing.T) {
5054
t.Skip("Skipping integration test in short mode")
5155
}
5256

53-
client := &ClaudeClient{BinPath: "echo"}
57+
skipIfNoClaudeCLI(t)
58+
client := newTestClient(t)
5459
ctx := context.Background()
5560

5661
// Test synchronous Query method
@@ -198,4 +203,68 @@ func contains(slice []string, item string) bool {
198203
}
199204
}
200205
return false
206+
}
207+
208+
// Local test helper functions to avoid import cycle with test/utils
209+
210+
// skipIfNoClaudeCLI skips the test if Claude Code CLI is not available
211+
func skipIfNoClaudeCLI(t *testing.T) {
212+
if os.Getenv("USE_MOCK_SERVER") == "1" {
213+
return // Mock server tests don't need Claude CLI
214+
}
215+
216+
claudePath := getTestClaudePath(t)
217+
if _, err := exec.LookPath(claudePath); err != nil {
218+
t.Skipf("Skipping test: Claude Code CLI not found at '%s'. Install Claude Code CLI or use mock server with USE_MOCK_SERVER=1", claudePath)
219+
}
220+
221+
// Test if Claude CLI is working (it will handle auth automatically)
222+
cmd := exec.Command(claudePath, "--help")
223+
if err := cmd.Run(); err != nil {
224+
t.Skip("Skipping test: Claude Code CLI not working. Please ensure it's properly installed.")
225+
}
226+
}
227+
228+
// newTestClient creates a Claude client for testing
229+
func newTestClient(t *testing.T) *ClaudeClient {
230+
return NewClient(getTestClaudePath(t))
231+
}
232+
233+
// getTestClaudePath returns the path to Claude CLI for testing
234+
func getTestClaudePath(t *testing.T) string {
235+
if path := os.Getenv("CLAUDE_CODE_PATH"); path != "" {
236+
return path
237+
}
238+
239+
// Check if mock server is preferred
240+
if os.Getenv("USE_MOCK_SERVER") == "1" {
241+
return createMockClaudeScript(t)
242+
}
243+
244+
return "claude" // Default Claude Code CLI
245+
}
246+
247+
// createMockClaudeScript creates a shell script that returns valid JSON
248+
func createMockClaudeScript(t *testing.T) string {
249+
tempDir := t.TempDir()
250+
mockScript := filepath.Join(tempDir, "mock-claude.sh")
251+
252+
content := `#!/bin/bash
253+
# Mock Claude CLI that returns appropriate JSON response based on format
254+
if [[ "$*" == *"--format json"* ]]; then
255+
# For regular JSON format (QuerySync)
256+
echo '{"result": "Mock response", "sessionId": "test-session", "costUSD": 0.001, "durationMS": 100}'
257+
else
258+
# For streaming format or default (Query)
259+
echo '{"type": "system", "subtype": "init", "content": "Initializing"}'
260+
echo '{"type": "assistant", "content": "Mock streaming response"}'
261+
echo '{"type": "result", "result": "Mock response", "sessionId": "test-session", "costUSD": 0.001, "durationMS": 100}'
262+
fi
263+
`
264+
265+
if err := os.WriteFile(mockScript, []byte(content), 0755); err != nil {
266+
t.Fatalf("Failed to create mock script: %v", err)
267+
}
268+
269+
return mockScript
201270
}

0 commit comments

Comments
 (0)