Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,19 @@ models:
| `role_session_name` | string | Session name for assumed role | cagent-bedrock-session |
| `external_id` | string | External ID for role assumption | (none) |
| `endpoint_url` | string | Custom endpoint (VPC/testing) | (none) |
| `interleaved_thinking` | bool | Enable reasoning during tool calls (requires thinking_budget) | false |
| `disable_prompt_caching` | bool | Disable automatic prompt caching | false |

#### Prompt Caching (Bedrock)

Prompt caching is automatically enabled for models that support it (detected via models.dev) to reduce latency and costs. System prompts, tool definitions, and recent messages are cached with a 5-minute TTL.

To disable:

```yaml
provider_opts:
disable_prompt_caching: true
```

**Supported models (via Converse API):**

Expand Down
69 changes: 53 additions & 16 deletions pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ import (
"github.com/docker/cagent/pkg/environment"
"github.com/docker/cagent/pkg/model/provider/base"
"github.com/docker/cagent/pkg/model/provider/options"
"github.com/docker/cagent/pkg/modelsdev"
"github.com/docker/cagent/pkg/tools"
)

// Client represents a Bedrock client wrapper implementing provider.Provider
type Client struct {
base.Config
bedrockClient *bedrockruntime.Client
bedrockClient *bedrockruntime.Client
cachingSupported bool // Cached at init time for efficiency
}

// bearerTokenTransport adds Authorization header with bearer token to requests
Expand All @@ -40,7 +42,6 @@ func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, err
return t.base.RoundTrip(req)
}

// NewClient creates a new Bedrock client from the provided configuration
func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (*Client, error) {
if cfg == nil {
slog.Error("Bedrock client creation failed", "error", "model configuration is required")
Expand Down Expand Up @@ -109,19 +110,47 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro

bedrockClient := bedrockruntime.NewFromConfig(awsCfg, clientOpts...)

slog.Debug("Bedrock client created successfully", "model", cfg.Model, "region", awsCfg.Region)
// Detect prompt caching capability at init time for efficiency.
// Uses models.dev cache pricing as proxy for capability detection.
cachingSupported := detectCachingSupport(ctx, cfg.Model)

slog.Debug("Bedrock client created successfully",
"model", cfg.Model,
"region", awsCfg.Region,
"caching_supported", cachingSupported)

return &Client{
Config: base.Config{
ModelConfig: *cfg,
ModelOptions: globalOptions,
Env: env,
},
bedrockClient: bedrockClient,
bedrockClient: bedrockClient,
cachingSupported: cachingSupported,
}, nil
}

// buildAWSConfig creates AWS config with proper credentials using the default credential chain
// detectCachingSupport checks if a model supports prompt caching using models.dev data.
// Models with non-zero CacheRead/CacheWrite costs support prompt caching.
// Returns false on lookup failure (safe default for unsupported models).
func detectCachingSupport(ctx context.Context, model string) bool {
store, err := modelsdev.NewStore()
if err != nil {
slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err)
return false
}

modelID := "amazon-bedrock/" + model
m, err := store.GetModel(ctx, modelID)
if err != nil {
slog.Debug("Bedrock prompt caching disabled: model not found in models.dev",
"model_id", modelID, "error", err)
return false
}

return m.Cost != nil && (m.Cost.CacheRead > 0 || m.Cost.CacheWrite > 0)
}

func buildAWSConfig(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider) (aws.Config, error) {
var configOpts []func(*config.LoadOptions) error

Expand Down Expand Up @@ -169,7 +198,6 @@ func buildAWSConfig(ctx context.Context, cfg *latest.ModelConfig, env environmen
return awsCfg, nil
}

// CreateChatCompletionStream creates a streaming chat completion request
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
messages []chat.Message,
Expand Down Expand Up @@ -198,21 +226,22 @@ func (c *Client) CreateChatCompletionStream(
return newStreamAdapter(output.GetStream(), c.ModelConfig.Model, trackUsage), nil
}

// buildConverseStreamInput creates the ConverseStream input parameters
func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools []tools.Tool) *bedrockruntime.ConverseStreamInput {
input := &bedrockruntime.ConverseStreamInput{
ModelId: aws.String(c.ModelConfig.Model),
}

enableCaching := c.promptCachingEnabled()

// Convert and set messages (excluding system)
input.Messages, input.System = convertMessages(messages)
input.Messages, input.System = convertMessages(messages, enableCaching)

// Set inference configuration
input.InferenceConfig = c.buildInferenceConfig()

// Convert and set tools
if len(requestTools) > 0 {
input.ToolConfig = convertToolConfig(requestTools)
input.ToolConfig = convertToolConfig(requestTools, enableCaching)
}

// Set extended thinking configuration for Claude models
Expand All @@ -223,7 +252,6 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
return input
}

// buildInferenceConfig creates the inference configuration
func (c *Client) buildInferenceConfig() *types.InferenceConfiguration {
cfg := &types.InferenceConfiguration{}

Expand All @@ -247,8 +275,8 @@ func (c *Client) buildInferenceConfig() *types.InferenceConfiguration {
return cfg
}

// isThinkingEnabled checks if extended thinking will be enabled for this request.
// This mirrors the validation logic in buildAdditionalModelRequestFields.
// isThinkingEnabled mirrors the validation in buildAdditionalModelRequestFields
// to determine if thinking params will affect inference config (temp/topP constraints).
func (c *Client) isThinkingEnabled() bool {
if c.ModelConfig.ThinkingBudget == nil || c.ModelConfig.ThinkingBudget.Tokens <= 0 {
return false
Expand All @@ -269,13 +297,18 @@ func (c *Client) isThinkingEnabled() bool {
return true
}

// interleavedThinkingEnabled returns true when provider_opts.interleaved_thinking is set.
func (c *Client) interleavedThinkingEnabled() bool {
return getProviderOpt[bool](c.ModelConfig.ProviderOpts, "interleaved_thinking")
}

// buildAdditionalModelRequestFields creates model-specific parameters.
// Used for extended thinking (reasoning) configuration on Claude models.
func (c *Client) promptCachingEnabled() bool {
if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
return false
}
return c.cachingSupported
}

// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode.
func (c *Client) buildAdditionalModelRequestFields() document.Interface {
if c.ModelConfig.ThinkingBudget == nil || c.ModelConfig.ThinkingBudget.Tokens <= 0 {
return nil
Expand Down Expand Up @@ -316,7 +349,6 @@ func (c *Client) buildAdditionalModelRequestFields() document.Interface {
return document.NewLazyDocument(fields)
}

// getProviderOpt extracts a typed value from provider_opts
func getProviderOpt[T any](opts map[string]any, key string) T {
var zero T
if opts == nil {
Expand All @@ -328,6 +360,11 @@ func getProviderOpt[T any](opts map[string]any, key string) T {
}
typed, ok := v.(T)
if !ok {
slog.Warn("Bedrock provider_opts type mismatch",
"key", key,
"expected_type", fmt.Sprintf("%T", zero),
"actual_type", fmt.Sprintf("%T", v),
"value", v)
return zero
}
return typed
Expand Down
Loading