Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion crates/api/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,48 @@ pub struct ModelAttestation {
pub info: Option<serde_json::Value>,
}

/// Agent attestation from agent instance
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct AgentAttestation {
/// Agent instance name
pub name: String,

/// Container image digest
#[serde(skip_serializing_if = "Option::is_none")]
pub image_digest: Option<String>,

/// TDX event log
#[serde(skip_serializing_if = "Option::is_none")]
pub event_log: Option<String>,

/// Additional TDX/tappd info (structured JSON)
#[serde(skip_serializing_if = "Option::is_none")]
pub info: Option<serde_json::Value>,

/// Intel TDX quote in hex format
#[serde(skip_serializing_if = "Option::is_none")]
pub intel_quote: Option<String>,

/// Request nonce
#[serde(skip_serializing_if = "Option::is_none")]
pub request_nonce: Option<String>,

/// TLS certificate
#[serde(skip_serializing_if = "Option::is_none")]
pub tls_certificate: Option<String>,

/// TLS certificate fingerprint
#[serde(skip_serializing_if = "Option::is_none")]
pub tls_certificate_fingerprint: Option<String>,
}

/// Complete attestation report combining all layers
///
/// This report proves the entire trust chain:
/// 1. This chat-api service runs in a TEE (your_gateway_attestation)
/// 2. The cloud-api dependency runs in a TEE (cloud_api_gateway_attestation)
/// 2. The cloud-api dependency runs in a TEE (cloud_api_gateway_attestation)
/// 3. The model inference providers run on trusted hardware (model_attestations)
/// 4. Optional agent instance attestations when agent parameter is provided
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct CombinedAttestationReport {
/// This chat-api's own CPU attestation (proves this service runs in a TEE)
Expand All @@ -161,6 +197,10 @@ pub struct CombinedAttestationReport {
/// Model provider attestations (can be multiple when routing to different models)
#[serde(skip_serializing_if = "Option::is_none")]
pub model_attestations: Option<Vec<ModelAttestation>>,

/// Agent instance attestations (included when agent query parameter is provided)
#[serde(skip_serializing_if = "Option::is_none")]
pub agent_attestations: Option<Vec<AgentAttestation>>,
}

/// Attestation report structure from proxy_service
Expand Down
244 changes: 238 additions & 6 deletions crates/api/src/routes/attestation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@ use crate::{
state::AppState,
ApiError,
};
use axum::{extract::Query, extract::State, response::Json, routing::get, Router};
use axum::{
extract::{Query, State},
response::Json,
routing::get,
Router,
};
use futures::TryStreamExt;
use http::Method;
use serde::{Deserialize, Serialize};
use services::vpc::load_vpc_info;

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AttestationQuery {
/// Optional model name to get specific attestations
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -27,6 +32,10 @@ pub struct AttestationQuery {
/// Signing address
#[serde(skip_serializing_if = "Option::is_none")]
pub signing_address: Option<String>,

/// Optional agent instance ID to get agent attestations
#[serde(skip_serializing_if = "Option::is_none")]
pub agent: Option<String>,
}

/// GET /v1/attestation/report
Expand All @@ -46,7 +55,8 @@ pub struct AttestationQuery {
("model" = Option<String>, Query, description = "Optional model name to filter model attestations"),
("signing_algo" = Option<String>, Query, description = "Signing algorithm: 'ecdsa' or 'ed25519'"),
("nonce" = Option<String>, Query, description = "64 length (32 bytes) hex string"),
("signing_address" = Option<String>, Query, description = "Query the attestation of the specific model that owns this signing address")
("signing_address" = Option<String>, Query, description = "Query the attestation of the specific model that owns this signing address"),
("agent" = Option<String>, Query, description = "Optional agent instance ID to include agent attestations in response")
),
responses(
(status = 200, description = "Combined attestation report", body = CombinedAttestationReport),
Expand All @@ -58,7 +68,11 @@ pub async fn get_attestation_report(
State(app_state): State<AppState>,
Query(params): Query<AttestationQuery>,
) -> Result<Json<CombinedAttestationReport>, ApiError> {
let query = serde_urlencoded::to_string(&params).expect("Failed to serialize query string");
// Exclude agent parameter from cloud-api query since it's not relevant there
let mut cloud_api_params = params.clone();
cloud_api_params.agent = None;
let query =
serde_urlencoded::to_string(&cloud_api_params).expect("Failed to serialize query string");

// Build the path for proxy_service attestation endpoint
let path = format!("attestation/report?{}", query);
Expand Down Expand Up @@ -131,7 +145,7 @@ pub async fn get_attestation_report(
signing_algo: None,
intel_quote: "0x1234567890abcdef".to_string(),
event_log: None,
request_nonce,
request_nonce: request_nonce.clone(),
info: None,
vpc: vpc_info,
}
Expand Down Expand Up @@ -163,7 +177,7 @@ pub async fn get_attestation_report(
info: Some(serde_json::to_value(info).map_err(|_| {
ApiError::internal_server_error("Failed to serialize attestation info")
})?),
request_nonce,
request_nonce: request_nonce.clone(),
vpc: vpc_info,
}
};
Expand All @@ -172,15 +186,233 @@ pub async fn get_attestation_report(

let model_attestations = proxy_report.model_attestations;

// Fetch agent attestations if agent parameter is provided (no user auth required)
let agent_attestations = if let Some(agent_id) = &params.agent {
match fetch_agent_attestations(&app_state, agent_id, &request_nonce).await {
Ok(attestations) => Some(attestations),
Err(e) => {
tracing::warn!("Failed to fetch agent attestations: {:?}", e);
// Don't fail the entire request if agent attestation fetch fails
None
}
}
} else {
None
};

let report = CombinedAttestationReport {
chat_api_gateway_attestation,
cloud_api_gateway_attestation,
model_attestations,
agent_attestations,
};

Ok(Json(report))
}

/// Fetch agent attestations from compose-api
#[derive(Debug, Deserialize)]
struct AgentAttestationResponse {
event_log: Option<String>,
quote: Option<String>,
#[serde(default)]
info: Option<serde_json::Value>,
tls_certificate: Option<String>,
tls_certificate_fingerprint: Option<String>,
}

#[derive(Debug, Deserialize)]
struct AgentInstanceAttestationResponse {
image_digest: Option<String>,
name: String,
}

/// Validate nonce is properly formatted and reasonable length (replay protection)
fn validate_nonce(nonce: &str) -> Result<(), ApiError> {
// Nonce should be a valid hex string of reasonable length (64 chars = 32 bytes)
const EXPECTED_NONCE_LEN: usize = 64;
const MAX_NONCE_LEN: usize = 256;

if nonce.len() > MAX_NONCE_LEN {
tracing::warn!("Nonce exceeds maximum length: {}", nonce.len());
return Err(ApiError::bad_request("Nonce is too long"));
}

if !nonce.chars().all(|c| c.is_ascii_hexdigit()) {
tracing::warn!("Nonce contains non-hex characters");
return Err(ApiError::bad_request("Nonce must be a valid hex string"));
}

if nonce.len() != EXPECTED_NONCE_LEN {
tracing::warn!(
"Nonce has unexpected length: {} (expected {})",
nonce.len(),
EXPECTED_NONCE_LEN
);
return Err(ApiError::bad_request(format!(
"Nonce must be exactly {} characters",
EXPECTED_NONCE_LEN
)));
}

Ok(())
}

/// Validate instance name doesn't contain path traversal sequences
fn validate_instance_name(name: &str) -> Result<(), ApiError> {
// Reject names containing path traversal sequences
if name.contains("..") || name.contains("/") || name.contains("\\") {
tracing::warn!("Instance name contains invalid characters: {}", name);
return Err(ApiError::bad_request(
"Instance name contains invalid characters",
));
}

if name.is_empty() {
return Err(ApiError::bad_request("Instance name cannot be empty"));
}

Ok(())
}

/// Helper function to handle HTTP response from proxy service (DRY)
async fn handle_proxy_response(
response: services::response::ports::ProxyResponse,
context: &str,
) -> Result<bytes::Bytes, ApiError> {
if response.status < 200 || response.status >= 300 {
tracing::error!(
"Proxy service returned error status {} for {}",
response.status,
context
);
return Err(ApiError::service_unavailable(format!(
"{} service returned error: {}",
context, response.status
)));
}

Ok(response
.body
.try_collect::<Vec<_>>()
.await
.map_err(|e| {
tracing::error!("Failed to read {} response: {}", context, e);
ApiError::internal_server_error(format!("Failed to read {} response", context))
})?
.into_iter()
.flatten()
.collect())
}

async fn fetch_agent_attestations(
app_state: &AppState,
agent_id: &str,
request_nonce: &str,
) -> Result<Vec<crate::models::AgentAttestation>, ApiError> {
use uuid::Uuid;

// Security: Validate nonce to prevent panic/DoS from malformed input
validate_nonce(request_nonce)?;

// Parse the agent_id as UUID
let agent_uuid = Uuid::parse_str(agent_id).map_err(|e| {
tracing::error!("Invalid agent ID format: {}", e);
ApiError::bad_request(format!("Invalid agent ID format: {}", e))
})?;

// Fetch the agent instance from database (no user_id check - attestation is public)
let agent_instance = app_state
.agent_repository
.get_instance(agent_uuid)
.await
.map_err(|e| {
tracing::error!("Failed to fetch agent instance from database: {}", e);
ApiError::internal_server_error("Failed to fetch agent instance")
})?
.ok_or_else(|| {
tracing::warn!("Agent instance not found: {}", agent_id);
ApiError::not_found("Agent instance not found")
})?;

// Security: Validate instance name to prevent path traversal attacks
validate_instance_name(&agent_instance.name)?;

let instance_name = &agent_instance.name;

// URL-encode instance name for safe URL construction
let encoded_instance_name = urlencoding::encode(instance_name);

// Build paths for both requests
// NOTE: Nonce is critical for replay protection - bind the quote to the client's nonce
let attestation_path = format!("attestation/report?nonce={}", request_nonce);
let instance_attestation_path = format!("instances/{}/attestation", encoded_instance_name);

// Fetch both attestations concurrently to minimize latency
let (attestation_response, instance_response) = tokio::join!(
app_state.proxy_service.forward_request(
Method::GET,
&attestation_path,
http::HeaderMap::new(),
None
),
app_state.proxy_service.forward_request(
Method::GET,
&instance_attestation_path,
http::HeaderMap::new(),
None,
)
);

// Handle responses
let attestation_response = attestation_response.map_err(|e| {
tracing::error!(
"Failed to fetch agent attestation report from compose-api: {}",
e
);
ApiError::bad_gateway(format!("Failed to fetch agent attestation: {}", e))
})?;

let attestation_bytes =
handle_proxy_response(attestation_response, "Agent attestation").await?;

let attestation_data: AgentAttestationResponse = serde_json::from_slice(&attestation_bytes)
.map_err(|e| {
tracing::error!("Failed to parse agent attestation response: {}", e);
ApiError::internal_server_error("Failed to parse agent attestation")
})?;

let instance_response = instance_response.map_err(|e| {
tracing::error!(
"Failed to fetch agent instance attestation from compose-api: {}",
e
);
ApiError::bad_gateway(format!("Failed to fetch instance attestation: {}", e))
})?;

let instance_bytes = handle_proxy_response(instance_response, "Instance attestation").await?;

let instance_data: AgentInstanceAttestationResponse = serde_json::from_slice(&instance_bytes)
.map_err(|e| {
tracing::error!("Failed to parse instance attestation response: {}", e);
ApiError::internal_server_error("Failed to parse instance attestation")
})?;

// Combine the data
let agent_attestation = crate::models::AgentAttestation {
name: instance_data.name,
image_digest: instance_data.image_digest,
event_log: attestation_data.event_log,
info: attestation_data.info,
intel_quote: attestation_data.quote,
request_nonce: Some(request_nonce.to_string()),
tls_certificate: attestation_data.tls_certificate,
tls_certificate_fingerprint: attestation_data.tls_certificate_fingerprint,
};

Ok(vec![agent_attestation])
}

/// Create the attestation router
pub fn create_attestation_router() -> Router<AppState> {
Router::new().route("/v1/attestation/report", get(get_attestation_report))
Expand Down