Skip to content

feat: Add custom model support with AWS credentials #2538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,28 @@ book/
.env*

run-build.sh

# Claude Flow generated files
.claude/settings.local.json
.mcp.json
claude-flow.config.json
.swarm/
.hive-mind/
memory/claude-flow-data.json
memory/sessions/*
!memory/sessions/README.md
memory/agents/*
!memory/agents/README.md
coordination/memory_bank/*
coordination/subtasks/*
coordination/orchestration/*
*.db
*.db-journal
*.db-wal
*.sqlite
*.sqlite-journal
*.sqlite-wal
claude-flow
claude-flow.bat
claude-flow.ps1
hive-mind-prompt-*.txt
191 changes: 191 additions & 0 deletions crates/chat-cli/src/api_client/custom_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use aws_credential_types::provider::ProvideCredentials;
use tracing::{
debug,
info,
};

use crate::api_client::credentials::CredentialsChain;

/// Parse custom model format: custom:<region>:<actual-model-id>
/// Example: custom:us-east-1:us.anthropic.claude-3-5-sonnet-20241022-v2:0
fn parse_custom_model(model_id: &str) -> Option<(String, String)> {
if !model_id.starts_with("custom:") {
return None;
}

// Remove "custom:" prefix
let without_prefix = &model_id[7..];

// Find the first colon to separate region from model ID
if let Some(colon_pos) = without_prefix.find(':') {
let region = without_prefix[..colon_pos].to_string();
let actual_model_id = without_prefix[colon_pos + 1..].to_string();

return Some((region, actual_model_id));
}

None
}

/// Handle custom model requests using AWS credentials
pub struct CustomModelHandler {
pub region: String,
pub actual_model_id: String,
}

impl CustomModelHandler {
/// Parse a custom model ID string
/// Format: custom:<region>:<actual-model-id>
/// Example: custom:us-east-1:us.anthropic.claude-3-5-sonnet-20241022-v2:0
pub fn from_model_id(model_id: &str) -> Option<Self> {
parse_custom_model(model_id).map(|(region, actual_model_id)| Self {
region,
actual_model_id,
})
}

/// Check if this is a Bedrock/Anthropic model
#[allow(dead_code)]
pub fn is_bedrock(&self) -> bool {
self.actual_model_id.contains("anthropic") || self.actual_model_id.contains("claude")
}

/// Get the actual model ID for API calls (without custom: prefix)
pub fn get_model_id(&self) -> &str {
&self.actual_model_id
}

/// Set environment to use AWS credentials
pub fn setup_aws_auth(&self) {
// Set the environment variable to use SigV4 authentication
// Note: Using unsafe as required for dynamic configuration
unsafe {
std::env::set_var("AMAZON_Q_SIGV4", "1");

// Set the region if specified
if !self.region.is_empty() {
std::env::set_var("AWS_REGION", &self.region);
}
}

info!(
"Configured custom model with AWS authentication: region={}, model={}",
self.region, self.actual_model_id
);
}

/// Validate that AWS credentials are available
#[allow(dead_code)]
pub async fn validate_credentials() -> Result<(), String> {
let credentials_chain = CredentialsChain::new().await;
match credentials_chain.provide_credentials().await {
Ok(_) => {
debug!("AWS credentials validated successfully");
Ok(())
},
Err(e) => Err(format!("Failed to get AWS credentials: {}", e)),
}
}
}

// Add Debug trait implementation for better test output
impl std::fmt::Debug for CustomModelHandler {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CustomModelHandler")
.field("region", &self.region)
.field("actual_model_id", &self.actual_model_id)
.finish()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_custom_model_handler_creation() {
let handler = CustomModelHandler {
region: "us-east-1".to_string(),
actual_model_id: "CLAUDE_3_7_SONNET_20250219_V1_0".to_string(),
};

assert_eq!(handler.region, "us-east-1");
assert_eq!(handler.actual_model_id, "CLAUDE_3_7_SONNET_20250219_V1_0");
}

#[test]
fn test_from_model_id() {
let handler = CustomModelHandler::from_model_id("custom:us-west-2:test-model-id");
assert!(handler.is_some());
let handler = handler.unwrap();
assert_eq!(handler.region, "us-west-2");
assert_eq!(handler.actual_model_id, "test-model-id");
}

#[test]
fn test_is_bedrock() {
let handler1 = CustomModelHandler {
region: "us-east-1".to_string(),
actual_model_id: "anthropic.claude-3-5-sonnet".to_string(),
};
assert!(handler1.is_bedrock());

let handler2 = CustomModelHandler {
region: "us-east-1".to_string(),
actual_model_id: "claude-4-sonnet".to_string(),
};
assert!(handler2.is_bedrock());

let handler3 = CustomModelHandler {
region: "us-east-1".to_string(),
actual_model_id: "other-model".to_string(),
};
assert!(!handler3.is_bedrock());
}

#[test]
fn test_get_model_id() {
let handler = CustomModelHandler {
region: "eu-west-1".to_string(),
actual_model_id: "CLAUDE_SONNET_4_20250514_V1_0".to_string(),
};
assert_eq!(handler.get_model_id(), "CLAUDE_SONNET_4_20250514_V1_0");
}

#[test]
fn test_parse_custom_model() {
// Valid format
let result = parse_custom_model("custom:us-east-1:model-id");
assert!(result.is_some());
let (region, model) = result.unwrap();
assert_eq!(region, "us-east-1");
assert_eq!(model, "model-id");

// Invalid formats
assert!(parse_custom_model("invalid:format").is_none());
assert!(parse_custom_model("custom:").is_none());
assert!(parse_custom_model("custom:us-east-1").is_none());
assert!(parse_custom_model("").is_none());
}

#[test]
fn test_complex_model_ids() {
let result = parse_custom_model("custom:us-east-1:vendor:model:version:0");
assert!(result.is_some());
let (region, model) = result.unwrap();
assert_eq!(region, "us-east-1");
assert_eq!(model, "vendor:model:version:0");
}

#[test]
fn test_debug_trait() {
let handler = CustomModelHandler {
region: "ap-southeast-1".to_string(),
actual_model_id: "TEST_MODEL".to_string(),
};
let debug_str = format!("{:?}", handler);
assert!(debug_str.contains("CustomModelHandler"));
assert!(debug_str.contains("ap-southeast-1"));
assert!(debug_str.contains("TEST_MODEL"));
}
}
27 changes: 25 additions & 2 deletions crates/chat-cli/src/api_client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod credentials;
pub mod custom_model;
pub mod customization;
mod endpoints;
mod error;
Expand Down Expand Up @@ -35,6 +36,7 @@ use serde_json::Map;
use tracing::{
debug,
error,
info,
};

use crate::api_client::credentials::CredentialsChain;
Expand Down Expand Up @@ -85,6 +87,9 @@ impl ApiClient {
) -> Result<Self, ApiClientError> {
let endpoint = endpoint.unwrap_or(Endpoint::configured_value(database));

// Check if using custom model (bypasses authentication)
let _use_custom_model = env.get("AMAZON_Q_CUSTOM_MODEL").is_ok() || env.get("AMAZON_Q_SIGV4").is_ok();

let credentials = Credentials::new("xxx", "xxx", None, None, "xxx");
let bearer_sdk_config = aws_config::defaults(behavior_version())
.region(endpoint.region.clone())
Expand Down Expand Up @@ -121,10 +126,28 @@ impl ApiClient {
return Ok(this);
}

// If SIGV4_AUTH_ENABLED is true, use Q developer client
// Check if using custom model first
let custom_model = database
.settings
.get_string(Setting::ChatDefaultModel)
.and_then(|m| custom_model::CustomModelHandler::from_model_id(&m))
.or_else(|| {
// Also check environment variable
std::env::var("AMAZON_Q_MODEL")
.ok()
.and_then(|m| custom_model::CustomModelHandler::from_model_id(&m))
});

if let Some(ref cm) = custom_model {
// Setup AWS authentication for custom model
cm.setup_aws_auth();
info!("Using custom model: {} in region: {}", cm.get_model_id(), cm.region);
}

// If SIGV4_AUTH_ENABLED is true or using custom model, use Q developer client
let mut streaming_client = None;
let mut sigv4_streaming_client = None;
match env.get("AMAZON_Q_SIGV4").is_ok() {
match env.get("AMAZON_Q_SIGV4").is_ok() || custom_model.is_some() {
true => {
let credentials_chain = CredentialsChain::new().await;
if let Err(err) = credentials_chain.provide_credentials().await {
Expand Down
Loading