Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/apollo-mcp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ pub mod operations;
pub mod sanitize;
pub(crate) mod schema_tree_shake;
pub mod server;
pub mod server_config;
pub mod server_handler;
58 changes: 43 additions & 15 deletions crates/apollo-mcp-server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use std::path::PathBuf;

use crate::runtime::Serve;
use apollo_mcp_registry::platform_api::operation_collections::collection_poller::CollectionSource;
use apollo_mcp_registry::uplink::persisted_queries::ManifestSource;
use apollo_mcp_registry::uplink::schema::SchemaSource;
use apollo_mcp_server::custom_scalar_map::CustomScalarMap;
use apollo_mcp_server::errors::ServerError;
use apollo_mcp_server::operations::OperationSource;
use apollo_mcp_server::server::Server;
use apollo_mcp_server::server_config::ServerConfig;
use apollo_mcp_server::server_handler::ApolloMcpServerHandler;
use clap::Parser;
use clap::builder::Styles;
use clap::builder::styling::{AnsiColor, Effects};
use runtime::IdOrDefault;
use runtime::logging::Logging;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};

mod runtime;
Expand Down Expand Up @@ -109,19 +114,18 @@ async fn main() -> anyhow::Result<()> {
.then(|| config.graphos.graph_ref())
.transpose()?;

Ok(Server::builder()
.transport(config.transport)
.schema_source(schema_source)
.operation_source(operation_source)
.endpoint(config.endpoint.into_inner())
let server_handler = ApolloMcpServerHandler::new(config.headers.clone(), config.endpoint.clone());
let cancellation_token = CancellationToken::new();

let server_config = ServerConfig::builder()
.maybe_explorer_graph_ref(explorer_graph_ref)
.headers(config.headers)
.execute_introspection(config.introspection.execute.enabled)
.validate_introspection(config.introspection.validate.enabled)
.introspect_introspection(config.introspection.introspect.enabled)
.headers(config.headers.clone())
.execute_enabled(config.introspection.execute.enabled)
.validate_enabled(config.introspection.validate.enabled)
.introspect_enabled(config.introspection.introspect.enabled)
.introspect_minify(config.introspection.introspect.minify)
.search_enabled(config.introspection.search.enabled)
.search_minify(config.introspection.search.minify)
.search_introspection(config.introspection.search.enabled)
.mutation_mode(config.overrides.mutation_mode)
.disable_type_description(config.overrides.disable_type_description)
.disable_schema_description(config.overrides.disable_schema_description)
Expand All @@ -133,8 +137,32 @@ async fn main() -> anyhow::Result<()> {
)
.search_leaf_depth(config.introspection.search.leaf_depth)
.index_memory_bytes(config.introspection.search.index_memory_bytes)
.health_check(config.health_check)
.health_check(config.health_check.clone())
.build();
let state_machine = Server::builder()
.schema_source(schema_source)
.operation_source(operation_source)
.server_handler(server_handler.clone())
.cancellation_token(cancellation_token.child_token())
.server_config(server_config)
.build()
.start()
.await?)
.start();

let server = Serve::serve(
server_handler,
config.transport,
cancellation_token,
config.health_check,
);

let (state_machine_result, server_result) = tokio::join!(state_machine, server);

match (state_machine_result, server_result) {
(Ok(()), Ok(())) => {
Ok(())
},
(Err(state_error), Err(server_error)) => anyhow::bail!("Both state machine and server have errors {} | {}", state_error, server_error),
(Err(state_error), _) => Err(state_error.into()),
(_, Err(server_error)) => Err(server_error.into()),
}
}
2 changes: 2 additions & 0 deletions crates/apollo-mcp-server/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod operation_source;
mod overrides;
mod schema_source;
mod schemas;
mod serve;

use std::path::Path;

Expand All @@ -22,6 +23,7 @@ use figment::{
};
pub use operation_source::{IdOrDefault, OperationSource};
pub use schema_source::SchemaSource;
pub use serve::Serve;

/// Separator to use when drilling down into nested options in the env figment
const ENV_NESTED_SEPARATOR: &str = "__";
Expand Down
185 changes: 185 additions & 0 deletions crates/apollo-mcp-server/src/runtime/serve.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
use apollo_mcp_server::auth::Config;
use apollo_mcp_server::errors::ServerError;
use apollo_mcp_server::health::{HealthCheck, HealthCheckConfig};
use apollo_mcp_server::server::Transport;
use apollo_mcp_server::server::states::shutdown_signal;
use apollo_mcp_server::server_handler::ApolloMcpServerHandler;
use axum::extract::Query;
use axum::routing::get;
use axum::{Json, Router};
use http::StatusCode;
use rmcp::service::{RunningService, ServerInitializeError};
use rmcp::transport::sse_server::SseServerConfig;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::transport::{SseServer, StreamableHttpService, stdio};
use rmcp::{RoleServer, ServiceExt};
use serde_json::json;
use std::io::Error;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, trace};

// Helper to enable auth
macro_rules! with_auth {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless we are going to disable/modify this macro based on feature flags I would just move this to a function.

($router:expr, $auth:ident) => {{
let mut router = $router;
if let Some(auth) = $auth {
router = auth.enable_middleware(router);
}

router
}};
}

pub struct Serve;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The Serve name here is a little confusing as verbs are usually Traits in rust.


impl Serve {
pub async fn serve(
server_handler: ApolloMcpServerHandler,
transport: Transport,
cancellation_token: CancellationToken,
health_check_config: HealthCheckConfig,
) -> Result<(), ServerError> {
match transport {
Transport::StreamableHttp {
auth,
address,
port,
} => {
serve_streamable_http(auth, address, port, server_handler, health_check_config)
.await?;
}
Transport::SSE {
auth,
address,
port,
} => {
serve_sse(auth, address, port, server_handler, cancellation_token).await?;
}
Transport::Stdio => {
let service = serve_stdio(server_handler)
.await
.map_err(|e| ServerError::McpInitializeError(e.into()))?;
service.waiting().await.map_err(ServerError::StartupError)?;
}
}

Ok(())
}
}

// Create health check if enabled (only for StreamableHttp transport)
fn create_health_check(config: HealthCheckConfig) -> Option<HealthCheck> {
// let telemetry: Arc<dyn Telemetry> = Arc::new(InMemoryTelemetry::new());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

@alocay alocay Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is left over from a previous change I was doing, ignore it for now.

Some(HealthCheck::new(config))
}

async fn serve_streamable_http(
auth: Option<Config>,
address: IpAddr,
port: u16,
server_handler: ApolloMcpServerHandler,
health_check_config: HealthCheckConfig,
) -> Result<(), ServerError> {
info!(port = ?port, address = ?address, "Starting MCP server in Streamable HTTP mode");
let listen_address = SocketAddr::new(address, port);
let service = StreamableHttpService::new(
move || Ok(server_handler.clone()),
LocalSessionManager::default().into(),
Default::default(),
);

let mut router = with_auth!(Router::new().nest_service("/mcp", service), auth);

// Add health check endpoint if configured
if health_check_config.enabled {
if let Some(health_check) = create_health_check(health_check_config) {
let health_router = Router::new()
.route(&health_check.config().path, get(health_endpoint))
.with_state(health_check.clone());
router = router.merge(health_router);
}
}

let tcp_listener = tokio::net::TcpListener::bind(listen_address).await?;
tokio::spawn(async move {
// Health check is already active from creation
if let Err(e) = axum::serve(tcp_listener, router)
.with_graceful_shutdown(shutdown_signal())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also use the cancellation token?

.await
{
// This can never really happen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why cant this happen? If it is really impossible we should mark it as unreachable!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just copied over from what was there. I'd have to take a look to see what this comment was referring to.

error!("Failed to start MCP server: {e:?}");
}
});

Ok(())
}

async fn serve_sse(
auth: Option<Config>,
address: IpAddr,
port: u16,
server_handler: ApolloMcpServerHandler,
cancellation_token: CancellationToken,
) -> Result<(), Error> {
info!(port = ?port, address = ?address, "Starting MCP server in SSE mode");
let listen_address = SocketAddr::new(address, port);

let (server, router) = SseServer::new(SseServerConfig {
bind: listen_address,
sse_path: "/sse".to_string(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Implement Default for SseServerConfig then this can be reduced to:

SseServerConfig {
   bind: listen_address, 
   ct: canellation_token,
   ..Default::default()
}

post_path: "/message".to_string(),
ct: cancellation_token,
sse_keep_alive: None,
});

// Optionally wrap the router with auth, if enabled
let router = with_auth!(router, auth);

// Start up the SSE server
// Note: Until RMCP consolidates SSE with the same tower system as StreamableHTTP,
// we need to basically copy the implementation of `SseServer::serve_with_config` here.
let listener = tokio::net::TcpListener::bind(server.config.bind).await?;
let ct = server.config.ct.child_token();
let axum_server = axum::serve(listener, router).with_graceful_shutdown(async move {
ct.cancelled().await;
info!("mcp server cancelled");
});

tokio::spawn(
async move {
if let Err(e) = axum_server.await {
error!(error = %e, "mcp shutdown with error");
}
}
.instrument(tracing::info_span!("mcp-server", bind_address = %server.config.bind)),
);

server.with_service(move || server_handler.clone());
Ok(())
}

async fn serve_stdio(
server_handler: ApolloMcpServerHandler,
) -> Result<RunningService<RoleServer, ApolloMcpServerHandler>, ServerInitializeError<Error>> {
info!("Starting MCP server in stdio mode");
server_handler.serve(stdio()).await.inspect_err(|e| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I would use inspect_err for a number of the if-lets above.

error!("serving error: {:?}", e);
})
}

/// Health check endpoint handler
async fn health_endpoint(
axum::extract::State(health_check): axum::extract::State<HealthCheck>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<(StatusCode, Json<serde_json::Value>), StatusCode> {
let query = params.keys().next().map(|k| k.as_str());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Is it valid to have no params on a call to the health check endpoint?

let (health, status_code) = health_check.get_health_state(query);

trace!(?health, query = ?query, "health check");

Ok((status_code, Json(json!(health))))
}
Loading
Loading