Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions eio/handler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ type t = {
prompts: registered_prompt list;
completion_handler: completion_handler option;
task_handlers: task_handlers option;
mutable subscribed_uris: StringSet.t;
subscribed_uris: StringSet.t Atomic.t;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve independent subscription state per handler copy

Switching subscribed_uris to StringSet.t Atomic.t makes every record copy created via { s with ... } (for example through add_tool/add_resource/add_prompt) share the same mutable cell, whereas previously each copy had its own mutable field value. This means branching from a base handler into multiple handlers/servers will now leak resources/subscribe and resources/unsubscribe state across instances, so one instance’s subscriptions can unexpectedly affect another’s behavior.

Useful? React with 👍 / 👎.

}

let create ~name ~version ?instructions () =
{ name; version; instructions;
tools = []; resources = []; prompts = [];
completion_handler = None;
task_handlers = None;
subscribed_uris = StringSet.empty }
subscribed_uris = Atomic.make StringSet.empty }

(** H3 fix: Check for duplicate tool names before adding.
Previously, registering two tools with the same name would silently
Expand Down Expand Up @@ -129,7 +129,7 @@ let instructions s = s.instructions
let tools s = List.map (fun rt -> rt.tool) s.tools
let resources s = List.map (fun rr -> rr.resource) s.resources
let prompts s = List.map (fun rp -> rp.prompt) s.prompts
let subscribed_uris s = StringSet.elements s.subscribed_uris
let subscribed_uris s = StringSet.elements (Atomic.get s.subscribed_uris)

(* ── capabilities ─────────────────────────────────────── *)

Expand Down Expand Up @@ -341,12 +341,23 @@ let handle_prompts_get s ctx id params =

(* ── resources/subscribe + unsubscribe ────────────────── *)

(** Atomically update subscribed_uris via compare-and-set loop.
Lock-free: StringSet is immutable, so the CAS swaps a pointer. *)
let atomic_update_uris s f =
let rec loop () =
let old = Atomic.get s.subscribed_uris in
let updated = f old in
if Atomic.compare_and_set s.subscribed_uris old updated then ()
else loop ()
in
loop ()

let handle_resources_subscribe s id params =
match params with
| Some (`Assoc fields) ->
begin match List.assoc_opt "uri" fields with
| Some (`String uri) ->
s.subscribed_uris <- StringSet.add uri s.subscribed_uris;
atomic_update_uris s (StringSet.add uri);
Jsonrpc.make_response ~id ~result:(`Assoc [])
| _ ->
Jsonrpc.make_error ~id ~code:Error_codes.invalid_params
Expand All @@ -361,7 +372,7 @@ let handle_resources_unsubscribe s id params =
| Some (`Assoc fields) ->
begin match List.assoc_opt "uri" fields with
| Some (`String uri) ->
s.subscribed_uris <- StringSet.remove uri s.subscribed_uris;
atomic_update_uris s (StringSet.remove uri);
Jsonrpc.make_response ~id ~result:(`Assoc [])
| _ ->
Jsonrpc.make_error ~id ~code:Error_codes.invalid_params
Expand Down
165 changes: 165 additions & 0 deletions test/test_bug_fixes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,152 @@ let test_l2_validate_state_length_mismatch () =
Alcotest.(check bool) "different length"
false (Oauth_client.validate_state ~expected:"short" ~received:"longer_string")

(* ── C4: ensure_rng consolidation ──────────────── *)

let test_c4_ensure_rng_works_via_tls_helpers () =
(* C4 fix consolidated ensure_rng to Tls_helpers.ensure_rng only.
Oauth_client.generate_state calls ensure_rng_initialized which is
Tls_helpers.ensure_rng. If we can generate a state, the consolidated
initialization path works correctly. *)
let state = Oauth_client.generate_state () in
Alcotest.(check bool) "ensure_rng works via Tls_helpers"
true (String.length state > 0)

(* ── H1: session state transitions ───────────── *)

let test_h1_initialize_twice_fails () =
let session = Http_session.create () in
(match Http_session.state session with
| Http_session.Uninitialized -> ()
| _ -> Alcotest.fail "should start Uninitialized");
(match Http_session.initialize session with
| Ok _sid -> ()
| Error e -> Alcotest.fail e);
(match Http_session.state session with
| Http_session.Initializing -> ()
| _ -> Alcotest.fail "should be Initializing after initialize");
(* H1: second initialize must fail — concurrent call protection *)
(match Http_session.initialize session with
| Error _ -> ()
| Ok _ -> Alcotest.fail "H1 regression: double initialize should be rejected")

let test_h1_ready_requires_initializing () =
let session = Http_session.create () in
(* ready before initialize should fail *)
(match Http_session.ready session with
| Error _ -> ()
| Ok () -> Alcotest.fail "H1: ready on Uninitialized should fail");
ignore (Http_session.initialize session);
(match Http_session.ready session with
| Ok () -> ()
| Error e -> Alcotest.fail e);
(match Http_session.state session with
| Http_session.Ready -> ()
| _ -> Alcotest.fail "should be Ready after ready()")

let test_h1_close_from_any_state () =
let session = Http_session.create () in
Http_session.close session;
(match Http_session.state session with
| Http_session.Closed -> ()
| _ -> Alcotest.fail "should be Closed after close()")

(* ── H2: version negotiation with unknown version ── *)

let test_h2_unsupported_version_returns_latest () =
(* When a client requests a version the server doesn't know,
the server should fall back to latest (with a warning log). *)
let h = Mcp_protocol_eio.Handler.create ~name:"test" ~version:"1.0" () in
let params = `Assoc [
("protocolVersion", `String "9999-12-31");
("capabilities", `Assoc []);
("clientInfo", `Assoc [("name", `String "client"); ("version", `String "1")]);
] in
let req = Jsonrpc.make_request ~id:(Jsonrpc.Int 1) ~method_:"initialize" ~params () in
let log_ref = ref Logging.Warning in
let dummy_ctx : Mcp_protocol_eio.Handler.context = {
send_notification = (fun ~method_:_ ~params:_ -> Ok ());
send_log = (fun _ _ -> Ok ());
send_progress = (fun ~token:_ ~progress:_ ~total:_ -> Ok ());
request_sampling = (fun _ -> Error "n/a");
request_roots_list = (fun () -> Error "n/a");
request_elicitation = (fun _ -> Error "n/a");
} in
match Mcp_protocol_eio.Handler.dispatch h dummy_ctx log_ref req with
| Some (Response r) ->
(match r.result with
| `Assoc fields ->
(match List.assoc_opt "protocolVersion" fields with
| Some (`String v) ->
Alcotest.(check string) "falls back to latest" Version.latest v
| _ -> Alcotest.fail "missing protocolVersion in response")
| _ -> Alcotest.fail "expected assoc result")
| _ -> Alcotest.fail "H2: unsupported version should still produce a Response"

let test_h2_supported_version_negotiated () =
(* When a client requests a supported version, negotiate returns it. *)
let h = Mcp_protocol_eio.Handler.create ~name:"test" ~version:"1.0" () in
let params = `Assoc [
("protocolVersion", `String "2024-11-05");
("capabilities", `Assoc []);
("clientInfo", `Assoc [("name", `String "client"); ("version", `String "1")]);
] in
let req = Jsonrpc.make_request ~id:(Jsonrpc.Int 1) ~method_:"initialize" ~params () in
let log_ref = ref Logging.Warning in
let dummy_ctx : Mcp_protocol_eio.Handler.context = {
send_notification = (fun ~method_:_ ~params:_ -> Ok ());
send_log = (fun _ _ -> Ok ());
send_progress = (fun ~token:_ ~progress:_ ~total:_ -> Ok ());
request_sampling = (fun _ -> Error "n/a");
request_roots_list = (fun () -> Error "n/a");
request_elicitation = (fun _ -> Error "n/a");
} in
match Mcp_protocol_eio.Handler.dispatch h dummy_ctx log_ref req with
| Some (Response r) ->
(match r.result with
| `Assoc fields ->
(match List.assoc_opt "protocolVersion" fields with
| Some (`String v) ->
Alcotest.(check string) "negotiated version" "2024-11-05" v
| _ -> Alcotest.fail "missing protocolVersion")
| _ -> Alcotest.fail "expected assoc result")
| _ -> Alcotest.fail "expected Response"

(* ── L1: max response size constant ──────────── *)

let test_l1_default_max_response_size () =
Alcotest.(check int) "default_max_response_size is 1MB"
(1024 * 1024) Oauth_client.default_max_response_size

(* ── M1: pre-init session validation ──────────── *)

let test_m1_pre_init_accepts_any_header () =
let session = Http_session.create () in
(* Before initialize, validate should accept any header value *)
(match Http_session.validate session None with
| Ok () -> ()
| Error _ -> Alcotest.fail "M1: pre-init should accept no header");
(match Http_session.validate session (Some "anything") with
| Ok () -> ()
| Error _ -> Alcotest.fail "M1: pre-init should accept any header value")

let test_m1_post_init_requires_matching_header () =
let session = Http_session.create () in
let sid = match Http_session.initialize session with
| Ok sid -> sid
| Error e -> Alcotest.fail e
in
(* After initialize, must match session id *)
(match Http_session.validate session (Some sid) with
| Ok () -> ()
| Error _ -> Alcotest.fail "matching header should pass");
(match Http_session.validate session (Some "wrong-id") with
| Error `Not_found -> ()
| _ -> Alcotest.fail "wrong header should return Not_found");
(match Http_session.validate session None with
| Error (`Bad_request _) -> ()
| _ -> Alcotest.fail "missing header should return Bad_request")

(* ── L3: scope checking with StringSet ───────── *)

let test_l3_scope_check_subset () =
Expand Down Expand Up @@ -294,6 +440,25 @@ let () =
Alcotest.test_case "validate mismatch" `Quick test_l2_validate_state_mismatch;
Alcotest.test_case "validate length mismatch" `Quick test_l2_validate_state_length_mismatch;
];
"C4_ensure_rng", [
Alcotest.test_case "consolidated via Tls_helpers" `Quick test_c4_ensure_rng_works_via_tls_helpers;
];
"H1_session_state", [
Alcotest.test_case "double initialize rejected" `Quick test_h1_initialize_twice_fails;
Alcotest.test_case "ready requires Initializing" `Quick test_h1_ready_requires_initializing;
Alcotest.test_case "close from any state" `Quick test_h1_close_from_any_state;
];
"H2_version_negotiation", [
Alcotest.test_case "unsupported version -> latest" `Quick test_h2_unsupported_version_returns_latest;
Alcotest.test_case "supported version negotiated" `Quick test_h2_supported_version_negotiated;
];
"L1_max_response_size", [
Alcotest.test_case "default is 1MB" `Quick test_l1_default_max_response_size;
];
"M1_pre_init_validation", [
Alcotest.test_case "pre-init accepts any" `Quick test_m1_pre_init_accepts_any_header;
Alcotest.test_case "post-init requires match" `Quick test_m1_post_init_requires_matching_header;
];
"L3_scope_check", [
Alcotest.test_case "subset passes" `Quick test_l3_scope_check_subset;
Alcotest.test_case "missing fails" `Quick test_l3_scope_check_missing;
Expand Down
Loading