diff --git a/internal/adapter/provider/cliproxyapi_codex/adapter.go b/internal/adapter/provider/cliproxyapi_codex/adapter.go index e05a48c3..f102053c 100644 --- a/internal/adapter/provider/cliproxyapi_codex/adapter.go +++ b/internal/adapter/provider/cliproxyapi_codex/adapter.go @@ -14,6 +14,7 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/provider" + "github.com/awsl-project/maxx/internal/codexutil" "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" @@ -21,8 +22,6 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" "github.com/router-for-me/CLIProxyAPI/v6/sdk/exec" "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) // TokenCache caches access tokens @@ -254,26 +253,7 @@ func (a *CLIProxyAPICodexAdapter) Execute(c *flow.Ctx, p *domain.Provider) error } func sanitizeCodexPayload(body []byte) []byte { - if input := gjson.GetBytes(body, "input"); input.IsArray() { - for i, item := range input.Array() { - itemType := item.Get("type").String() - if itemType != "message" { - if item.Get("role").Exists() { - body, _ = sjson.DeleteBytes(body, fmt.Sprintf("input.%d.role", i)) - } - } - if itemType == "function_call" { - if id := item.Get("id").String(); id != "" && !strings.HasPrefix(id, "fc_") { - body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.id", i), "fc_"+id) - } - } - if itemType == "function_call_output" { - if !item.Get("output").Exists() { - body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.output", i), "") - } - } - } - } + body = codexutil.NormalizeCodexInput(body) return body } diff --git a/internal/adapter/provider/codex/adapter.go b/internal/adapter/provider/codex/adapter.go index 2295ce21..824b14b5 100644 --- a/internal/adapter/provider/codex/adapter.go +++ b/internal/adapter/provider/codex/adapter.go @@ -16,6 +16,7 @@ import ( "github.com/awsl-project/maxx/internal/adapter/provider" cliproxyapi "github.com/awsl-project/maxx/internal/adapter/provider/cliproxyapi_codex" + "github.com/awsl-project/maxx/internal/codexutil" "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" @@ -572,26 +573,7 @@ func applyCodexRequestTuning(c *flow.Ctx, body []byte) (string, []byte) { if !gjson.GetBytes(body, "instructions").Exists() { body, _ = sjson.SetBytes(body, "instructions", "") } - if input := gjson.GetBytes(body, "input"); input.IsArray() { - for i, item := range input.Array() { - itemType := item.Get("type").String() - if itemType != "message" { - if item.Get("role").Exists() { - body, _ = sjson.DeleteBytes(body, fmt.Sprintf("input.%d.role", i)) - } - } - if itemType == "function_call" { - if id := item.Get("id").String(); id != "" && !strings.HasPrefix(id, "fc_") { - body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.id", i), "fc_"+id) - } - } - if itemType == "function_call_output" { - if !item.Get("output").Exists() { - body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.output", i), "") - } - } - } - } + body = codexutil.NormalizeCodexInput(body) return cacheID, body } diff --git a/internal/adapter/provider/codex/adapter_test.go b/internal/adapter/provider/codex/adapter_test.go index 5a1e6da2..dfafe67b 100644 --- a/internal/adapter/provider/codex/adapter_test.go +++ b/internal/adapter/provider/codex/adapter_test.go @@ -2,6 +2,7 @@ package codex import ( "net/http" + "strings" "testing" "github.com/awsl-project/maxx/internal/domain" @@ -14,7 +15,7 @@ func TestApplyCodexRequestTuning(t *testing.T) { c.Set(flow.KeyOriginalClientType, domain.ClientTypeClaude) c.Set(flow.KeyOriginalRequestBody, []byte(`{"metadata":{"user_id":"user-123"}}`)) - body := []byte(`{"model":"gpt-5","stream":false,"instructions":"x","previous_response_id":"r1","prompt_cache_retention":123,"safety_identifier":"s1","max_output_tokens":77,"input":[{"type":"message","role":"user","content":"hi"},{"type":"function_call","role":"assistant","name":"t","arguments":"{}"},{"role":"tool","call_id":"c1","output":"ok"}]}`) + body := []byte(`{"model":"gpt-5","stream":false,"instructions":"x","previous_response_id":"r1","prompt_cache_retention":123,"safety_identifier":"s1","max_output_tokens":77,"input":[{"type":"message","role":"user","content":"hi"},{"type":"function_call","role":"assistant","name":"t","arguments":"{}","id":"toolu_01"},{"type":"function_call","name":"t2","arguments":"{}"},{"type":"function_call_output","call_id":"c1"},{"role":"tool","call_id":"c1","output":"ok"}]}`) cacheID, tuned := applyCodexRequestTuning(c, body) if cacheID == "" { @@ -44,9 +45,19 @@ func TestApplyCodexRequestTuning(t *testing.T) { if gjson.GetBytes(tuned, "input.0.role").String() != "user" { t.Fatalf("expected role to be preserved for message input") } - if gjson.GetBytes(tuned, "input.1.role").Exists() || gjson.GetBytes(tuned, "input.2.role").Exists() { + if gjson.GetBytes(tuned, "input.1.role").Exists() || gjson.GetBytes(tuned, "input.2.role").Exists() || gjson.GetBytes(tuned, "input.3.role").Exists() || gjson.GetBytes(tuned, "input.4.role").Exists() { t.Fatalf("expected role to be removed for non-message inputs") } + if gjson.GetBytes(tuned, "input.1.id").String() != "fc_toolu_01" { + t.Fatalf("expected function_call id to be prefixed with fc_") + } + missingID := gjson.GetBytes(tuned, "input.2.id").String() + if !strings.HasPrefix(missingID, "fc_") || missingID == "fc_" { + t.Fatalf("expected generated function_call id to be set and prefixed with fc_") + } + if gjson.GetBytes(tuned, "input.3.output").String() != "" { + t.Fatalf("expected missing function_call_output output to default to empty string") + } } func TestApplyCodexHeadersFiltersSensitiveAndPreservesUA(t *testing.T) { diff --git a/internal/codexutil/normalize.go b/internal/codexutil/normalize.go new file mode 100644 index 00000000..8f76c012 --- /dev/null +++ b/internal/codexutil/normalize.go @@ -0,0 +1,43 @@ +package codexutil + +import ( + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +func NormalizeCodexInput(body []byte) []byte { + input := gjson.GetBytes(body, "input") + if !input.IsArray() { + return body + } + + for i, item := range input.Array() { + itemType := item.Get("type").String() + if itemType != "message" { + if item.Get("role").Exists() { + body, _ = sjson.DeleteBytes(body, fmt.Sprintf("input.%d.role", i)) + } + } + if itemType == "function_call" { + id := strings.TrimSpace(item.Get("id").String()) + switch { + case strings.HasPrefix(id, "fc_"): + case id == "": + body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.id", i), "fc_"+uuid.NewString()) + default: + body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.id", i), "fc_"+id) + } + } + if itemType == "function_call_output" { + if !item.Get("output").Exists() { + body, _ = sjson.SetBytes(body, fmt.Sprintf("input.%d.output", i), "") + } + } + } + + return body +}