Skip to content
Open
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ test-race:
CGO_ENABLED=1 go test -count=1 -race ./...

coverage:
go test -coverprofile=coverage.out ./...
go test -coverprofile=coverage.out -coverpkg=./... ./...
go tool cover -func=coverage.out | tail -n 1

coverage-html:
Expand Down
27 changes: 23 additions & 4 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -1792,16 +1793,31 @@ const mockToolName = "coder_list_workspaces"

// callAccumulator tracks all tool invocations by name and each instance's arguments.
type callAccumulator struct {
calls map[string][]any
callsMu sync.Mutex
calls map[string][]any
callsMu sync.Mutex
toolErrors map[string]string
}

func newCallAccumulator() *callAccumulator {
return &callAccumulator{
calls: make(map[string][]any),
calls: make(map[string][]any),
toolErrors: make(map[string]string),
}
}

// setToolError stores errMsg as the configured error message for tool,
// overwriting any message previously set for the same tool name.
// The write is performed while holding callsMu.
func (a *callAccumulator) setToolError(tool, errMsg string) {
	a.callsMu.Lock()
	defer a.callsMu.Unlock()

	a.toolErrors[tool] = errMsg
}

// getToolError looks up the error message configured for tool.
// The second result reports whether an entry exists for tool; when it does
// not, the returned message is the empty string.
// The lookup is performed while holding callsMu.
func (a *callAccumulator) getToolError(tool string) (string, bool) {
	a.callsMu.Lock()
	defer a.callsMu.Unlock()

	if msg, found := a.toolErrors[tool]; found {
		return msg, true
	}
	return "", false
}

func (a *callAccumulator) addCall(tool string, args any) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
Expand Down Expand Up @@ -1831,12 +1847,15 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
// Accumulate tool calls & their arguments.
acc := newCallAccumulator()

for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
return nil, errors.New(errMsg)
}
return mcplib.NewToolResultText("mock"), nil
})
}
Expand Down
10 changes: 8 additions & 2 deletions fixtures/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ var (
//go:embed openai/responses/blocking/simple.txtar
OaiResponsesBlockingSimple []byte

//go:embed openai/responses/blocking/builtin_tool.txtar
OaiResponsesBlockingBuiltinTool []byte
//go:embed openai/responses/blocking/single_builtin_tool.txtar
OaiResponsesBlockingSingleBuiltinTool []byte

//go:embed openai/responses/blocking/custom_tool.txtar
OaiResponsesBlockingCustomTool []byte
Expand All @@ -65,6 +65,12 @@ var (

//go:embed openai/responses/blocking/wrong_response_format.txtar
OaiResponsesBlockingWrongResponseFormat []byte

//go:embed openai/responses/blocking/single_injected_tool.txtar
OaiResponsesSingleInjectedTool []byte

//go:embed openai/responses/blocking/single_injected_tool_error.txtar
OaiResponsesSingleInjectedToolError []byte
)

var (
Expand Down
1,522 changes: 1,522 additions & 0 deletions fixtures/openai/responses/blocking/single_injected_tool.txtar

Large diffs are not rendered by default.

1,522 changes: 1,522 additions & 0 deletions fixtures/openai/responses/blocking/single_injected_tool_error.txtar

Large diffs are not rendered by default.

30 changes: 10 additions & 20 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import (
"github.com/google/uuid"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
oaiconst "github.com/openai/openai-go/v3/shared/constant"
"github.com/openai/openai-go/v3/shared/constant"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

const (
Expand All @@ -42,6 +42,7 @@ type responsesInterceptionBase struct {
reqPayload []byte
cfg config.OpenAI
model string
tracer trace.Tracer
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
logger slog.Logger
Expand Down Expand Up @@ -97,18 +98,6 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
return err
}

// keeping the same logic for 'parallel_tool_calls' as in chat-completions
// https://github.com/coder/aibridge/blob/7535a71e91a1d214a31a9b59bb810befb26141bc/intercept/chatcompletions/streaming.go#L99
if len(i.req.Tools) > 0 {
var err error
i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false)
if err != nil {
err = fmt.Errorf("failed set parallel_tool_calls parameter: %w", err)
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
return err
}
}

return nil
}

Expand Down Expand Up @@ -174,7 +163,7 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) {
inputItems := gjson.GetBytes(i.reqPayload, "input").Array()
for i := len(inputItems) - 1; i >= 0; i-- {
item := inputItems[i]
if item.Get("role").Str == "user" {
if item.Get("role").Str == string(constant.ValueOf[constant.User]()) {
var sb strings.Builder

// content can be a string or array of objects:
Expand All @@ -194,7 +183,8 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) {
}
}

return "", errors.New("failed to find last user prompt")
// Request was likely not human-initiated.
return "", nil
}

func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) {
Expand All @@ -204,8 +194,8 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon
return
}

// No prompt found: last request was not human-initiated.
if prompt == "" {
i.logger.Warn(ctx, "got empty last prompt, skipping prompt recording")
return
}

Expand All @@ -224,7 +214,7 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon
}
}

func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, response *responses.Response) {
func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Context, response *responses.Response) {
if response == nil {
i.logger.Warn(ctx, "got empty response, skipping tool usage recording")
return
Expand All @@ -235,9 +225,9 @@ func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, respons

// recording other function types to be considered: https://github.com/coder/aibridge/issues/121
switch item.Type {
case string(oaiconst.ValueOf[oaiconst.FunctionCall]()):
case string(constant.ValueOf[constant.FunctionCall]()):
args = i.parseFunctionCallJSONArgs(ctx, item.Arguments)
case string(oaiconst.ValueOf[oaiconst.CustomToolCall]()):
case string(constant.ValueOf[constant.CustomToolCall]()):
args = item.Input
default:
continue
Expand Down
24 changes: 4 additions & 20 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestLastUserPrompt(t *testing.T) {
},
{
name: "array_single_input_string",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingBuiltinTool),
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool),
expected: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
},
{
Expand Down Expand Up @@ -71,45 +71,30 @@ func TestLastUserPromptErr(t *testing.T) {
require.Contains(t, "cannot get last user prompt: nil struct", err.Error())
})

t.Run("nil_struct", func(t *testing.T) {
t.Parallel()

base := responsesInterceptionBase{}
prompt, err := base.lastUserPrompt()
require.Error(t, err)
require.Empty(t, prompt)
require.Contains(t, "cannot get last user prompt: nil req struct", err.Error())
})

// Other cases where the user prompt might be empty.
tests := []struct {
name string
reqPayload []byte
wantErrMsg string
}{
{
name: "empty_input",
reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "no_user_role",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "user_with_empty_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "user_with_empty_content_array",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "user_with_non_input_text_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
wantErrMsg: "failed to find last user prompt",
},
}

Expand All @@ -127,9 +112,8 @@ func TestLastUserPromptErr(t *testing.T) {
}

prompt, err := base.lastUserPrompt()
require.Error(t, err)
require.NoError(t, err)
require.Empty(t, prompt)
require.Contains(t, tc.wantErrMsg, err.Error())
})
}
}
Expand Down Expand Up @@ -318,7 +302,7 @@ func TestRecordToolUsage(t *testing.T) {
logger: slog.Make(),
}

base.recordToolUsage(t.Context(), tc.response)
base.recordNonInjectedToolUsage(t.Context(), tc.response)

tools := rec.RecordedToolUsages()
require.Len(t, tools, len(tc.expected))
Expand Down
95 changes: 87 additions & 8 deletions intercept/responses/blocking.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
package responses

import (
"context"
"errors"
"fmt"
"net/http"
"time"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/recorder"
"github.com/google/uuid"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

type BlockingResponsesInterceptor struct {
responsesInterceptionBase
}

func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string) *BlockingResponsesInterceptor {
func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *BlockingResponsesInterceptor {
return &BlockingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
id: id,
req: req,
reqPayload: reqPayload,
cfg: cfg,
model: model,
tracer: tracer,
},
}
}
Expand All @@ -46,16 +54,54 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
return err
}

srv := i.newResponsesService()
var respCopy responseCopier
i.injectTools()
i.disableParallelToolCalls()

opts := i.requestOptions(&respCopy)
response, upstreamErr := srv.New(ctx, i.req.ResponseNewParams, opts...)
var (
response *responses.Response
upstreamErr error
respCopy responseCopier
)

// response could be nil eg. fixtures/openai/responses/blocking/wrong_response_format.txtar
if response != nil {
for {
srv := i.newResponsesService()
respCopy = responseCopier{}

opts := i.requestOptions(&respCopy)
opts = append(opts, option.WithRequestTimeout(time.Second*600))
response, upstreamErr = srv.New(ctx, i.req.ResponseNewParams, opts...)

if upstreamErr != nil {
break
}

// response could be nil eg. fixtures/openai/responses/blocking/wrong_response_format.txtar
if response == nil {
break
}

// Record prompt usage on first successful response.
i.recordUserPrompt(ctx, response.ID)
i.recordToolUsage(ctx, response)

// Check if there are any injected tools to invoke.
pending := i.getPendingInjectedToolCalls(ctx, response)
if len(pending) == 0 {
// No injected tools, record non-injected tool usage.
i.recordNonInjectedToolUsage(ctx, response)

// No injected function calls need to be invoked, flow is complete.
break
}

shouldLoop, err := i.handleInnerAgenticLoop(ctx, pending, response)
if err != nil {
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
shouldLoop = false
}

if !shouldLoop {
break
}
}

if upstreamErr != nil && !respCopy.responseReceived.Load() {
Expand All @@ -67,3 +113,36 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *

return errors.Join(upstreamErr, err)
}

// handleInnerAgenticLoop orchestrates the inner agentic loop whereby injected tools
// are invoked and their results are sent back to the model.
// This is in contrast to regular tool calls which will be handled by the client
// in its own agentic loop.
//
// It returns true when the request has been extended with the tool results and the
// caller should send it upstream again. It returns false when no tool results were
// produced or the request could not be updated. A non-nil error is returned only
// when handling the injected tool calls fails.
func (i *BlockingResponsesInterceptor) handleInnerAgenticLoop(ctx context.Context, pending []responses.ResponseFunctionToolCall, response *responses.Response) (bool, error) {
// Invoke any injected function calls.
// The Responses API refers to what we call "tools" as "functions", so we keep the terminology
// consistent in this package.
// See https://platform.openai.com/docs/guides/function-calling
results, err := i.handleInjectedToolCalls(ctx, pending, response)
if err != nil {
return false, fmt.Errorf("failed to handle injected tool calls: %w", err)
}

// No tool results means no tools were invocable, so the flow is complete.
if len(results) == 0 {
return false, nil
}

// Append the tool results to the request input so that the follow-up request
// provides them to the model.
i.prepareRequestForAgenticLoop(response)
i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, results...)

// Keep the raw request payload in sync with the updated input.
i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input", i.req.Input)
if err != nil {
i.logger.Error(ctx, "failure to marshal new input in inner agentic loop", slog.Error(err))
// TODO: what should be returned under this condition?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs some discussion.

// NOTE(review): the failure is only logged here; the loop is stopped without
// returning an error to the caller, and i.req.Input has already been extended above.
return false, nil
}

return true, nil
}
Loading