diff --git a/eio/handler.ml b/eio/handler.ml index 649981b..321fbcc 100644 --- a/eio/handler.ml +++ b/eio/handler.ml @@ -56,7 +56,7 @@ type t = { prompts: registered_prompt list; completion_handler: completion_handler option; task_handlers: task_handlers option; - mutable subscribed_uris: StringSet.t; + subscribed_uris: StringSet.t Atomic.t; } let create ~name ~version ?instructions () = @@ -64,7 +64,7 @@ let create ~name ~version ?instructions () = tools = []; resources = []; prompts = []; completion_handler = None; task_handlers = None; - subscribed_uris = StringSet.empty } + subscribed_uris = Atomic.make StringSet.empty } (** H3 fix: Check for duplicate tool names before adding. Previously, registering two tools with the same name would silently @@ -129,7 +129,7 @@ let instructions s = s.instructions let tools s = List.map (fun rt -> rt.tool) s.tools let resources s = List.map (fun rr -> rr.resource) s.resources let prompts s = List.map (fun rp -> rp.prompt) s.prompts -let subscribed_uris s = StringSet.elements s.subscribed_uris +let subscribed_uris s = StringSet.elements (Atomic.get s.subscribed_uris) (* ── capabilities ─────────────────────────────────────── *) @@ -341,12 +341,23 @@ let handle_prompts_get s ctx id params = (* ── resources/subscribe + unsubscribe ────────────────── *) +(** Atomically update subscribed_uris via compare-and-set loop. + Lock-free: StringSet is immutable, so the CAS swaps a pointer. *) +let atomic_update_uris s f = + let rec loop () = + let old = Atomic.get s.subscribed_uris in + let updated = f old in + if Atomic.compare_and_set s.subscribed_uris old updated then () + else loop () + in + loop () + let handle_resources_subscribe s id params = match params with | Some (`Assoc fields) -> begin match List.assoc_opt "uri" fields with | Some (`String uri) -> - s.subscribed_uris <- StringSet.add uri s.subscribed_uris; + atomic_update_uris s (StringSet.add uri); Jsonrpc.make_response ~id ~result:(`Assoc []) | _ -> Jsonrpc.make_error ~id ~code:Error_codes.invalid_params @@ -361,7 +372,7 @@ let handle_resources_unsubscribe s id params = | Some (`Assoc fields) -> begin match List.assoc_opt "uri" fields with | Some (`String uri) -> - s.subscribed_uris <- StringSet.remove uri s.subscribed_uris; + atomic_update_uris s (StringSet.remove uri); Jsonrpc.make_response ~id ~result:(`Assoc []) | _ -> Jsonrpc.make_error ~id ~code:Error_codes.invalid_params diff --git a/test/test_bug_fixes.ml b/test/test_bug_fixes.ml index a14ff6a..afc732a 100644 --- a/test/test_bug_fixes.ml +++ b/test/test_bug_fixes.ml @@ -236,6 +236,152 @@ let test_l2_validate_state_length_mismatch () = Alcotest.(check bool) "different length" false (Oauth_client.validate_state ~expected:"short" ~received:"longer_string") +(* ── C4: ensure_rng consolidation ──────────────── *) + +let test_c4_ensure_rng_works_via_tls_helpers () = + (* C4 fix consolidated ensure_rng to Tls_helpers.ensure_rng only. + Oauth_client.generate_state calls ensure_rng_initialized which is + Tls_helpers.ensure_rng. If we can generate a state, the consolidated + initialization path works correctly. *) + let state = Oauth_client.generate_state () in + Alcotest.(check bool) "ensure_rng works via Tls_helpers" + true (String.length state > 0) + +(* ── H1: session state transitions ───────────── *) + +let test_h1_initialize_twice_fails () = + let session = Http_session.create () in + (match Http_session.state session with + | Http_session.Uninitialized -> () + | _ -> Alcotest.fail "should start Uninitialized"); + (match Http_session.initialize session with + | Ok _sid -> () + | Error e -> Alcotest.fail e); + (match Http_session.state session with + | Http_session.Initializing -> () + | _ -> Alcotest.fail "should be Initializing after initialize"); + (* H1: second initialize must fail — concurrent call protection *) + (match Http_session.initialize session with + | Error _ -> () + | Ok _ -> Alcotest.fail "H1 regression: double initialize should be rejected") + +let test_h1_ready_requires_initializing () = + let session = Http_session.create () in + (* ready before initialize should fail *) + (match Http_session.ready session with + | Error _ -> () + | Ok () -> Alcotest.fail "H1: ready on Uninitialized should fail"); + ignore (Http_session.initialize session); + (match Http_session.ready session with + | Ok () -> () + | Error e -> Alcotest.fail e); + (match Http_session.state session with + | Http_session.Ready -> () + | _ -> Alcotest.fail "should be Ready after ready()") + +let test_h1_close_from_any_state () = + let session = Http_session.create () in + Http_session.close session; + (match Http_session.state session with + | Http_session.Closed -> () + | _ -> Alcotest.fail "should be Closed after close()") + +(* ── H2: version negotiation with unknown version ── *) + +let test_h2_unsupported_version_returns_latest () = + (* When a client requests a version the server doesn't know, + the server should fall back to latest (with a warning log). *) + let h = Mcp_protocol_eio.Handler.create ~name:"test" ~version:"1.0" () in + let params = `Assoc [ + ("protocolVersion", `String "9999-12-31"); + ("capabilities", `Assoc []); + ("clientInfo", `Assoc [("name", `String "client"); ("version", `String "1")]); + ] in + let req = Jsonrpc.make_request ~id:(Jsonrpc.Int 1) ~method_:"initialize" ~params () in + let log_ref = ref Logging.Warning in + let dummy_ctx : Mcp_protocol_eio.Handler.context = { + send_notification = (fun ~method_:_ ~params:_ -> Ok ()); + send_log = (fun _ _ -> Ok ()); + send_progress = (fun ~token:_ ~progress:_ ~total:_ -> Ok ()); + request_sampling = (fun _ -> Error "n/a"); + request_roots_list = (fun () -> Error "n/a"); + request_elicitation = (fun _ -> Error "n/a"); + } in + match Mcp_protocol_eio.Handler.dispatch h dummy_ctx log_ref req with + | Some (Response r) -> + (match r.result with + | `Assoc fields -> + (match List.assoc_opt "protocolVersion" fields with + | Some (`String v) -> + Alcotest.(check string) "falls back to latest" Version.latest v + | _ -> Alcotest.fail "missing protocolVersion in response") + | _ -> Alcotest.fail "expected assoc result") + | _ -> Alcotest.fail "H2: unsupported version should still produce a Response" + +let test_h2_supported_version_negotiated () = + (* When a client requests a supported version, negotiate returns it. *) + let h = Mcp_protocol_eio.Handler.create ~name:"test" ~version:"1.0" () in + let params = `Assoc [ + ("protocolVersion", `String "2024-11-05"); + ("capabilities", `Assoc []); + ("clientInfo", `Assoc [("name", `String "client"); ("version", `String "1")]); + ] in + let req = Jsonrpc.make_request ~id:(Jsonrpc.Int 1) ~method_:"initialize" ~params () in + let log_ref = ref Logging.Warning in + let dummy_ctx : Mcp_protocol_eio.Handler.context = { + send_notification = (fun ~method_:_ ~params:_ -> Ok ()); + send_log = (fun _ _ -> Ok ()); + send_progress = (fun ~token:_ ~progress:_ ~total:_ -> Ok ()); + request_sampling = (fun _ -> Error "n/a"); + request_roots_list = (fun () -> Error "n/a"); + request_elicitation = (fun _ -> Error "n/a"); + } in + match Mcp_protocol_eio.Handler.dispatch h dummy_ctx log_ref req with + | Some (Response r) -> + (match r.result with + | `Assoc fields -> + (match List.assoc_opt "protocolVersion" fields with + | Some (`String v) -> + Alcotest.(check string) "negotiated version" "2024-11-05" v + | _ -> Alcotest.fail "missing protocolVersion") + | _ -> Alcotest.fail "expected assoc result") + | _ -> Alcotest.fail "expected Response" + +(* ── L1: max response size constant ──────────── *) + +let test_l1_default_max_response_size () = + Alcotest.(check int) "default_max_response_size is 1MB" + (1024 * 1024) Oauth_client.default_max_response_size + +(* ── M1: pre-init session validation ──────────── *) + +let test_m1_pre_init_accepts_any_header () = + let session = Http_session.create () in + (* Before initialize, validate should accept any header value *) + (match Http_session.validate session None with + | Ok () -> () + | Error _ -> Alcotest.fail "M1: pre-init should accept no header"); + (match Http_session.validate session (Some "anything") with + | Ok () -> () + | Error _ -> Alcotest.fail "M1: pre-init should accept any header value") + +let test_m1_post_init_requires_matching_header () = + let session = Http_session.create () in + let sid = match Http_session.initialize session with + | Ok sid -> sid + | Error e -> Alcotest.fail e + in + (* After initialize, must match session id *) + (match Http_session.validate session (Some sid) with + | Ok () -> () + | Error _ -> Alcotest.fail "matching header should pass"); + (match Http_session.validate session (Some "wrong-id") with + | Error `Not_found -> () + | _ -> Alcotest.fail "wrong header should return Not_found"); + (match Http_session.validate session None with + | Error (`Bad_request _) -> () + | _ -> Alcotest.fail "missing header should return Bad_request") + (* ── L3: scope checking with StringSet ───────── *) let test_l3_scope_check_subset () = @@ -294,6 +440,25 @@ let () = Alcotest.test_case "validate mismatch" `Quick test_l2_validate_state_mismatch; Alcotest.test_case "validate length mismatch" `Quick test_l2_validate_state_length_mismatch; ]; + "C4_ensure_rng", [ + Alcotest.test_case "consolidated via Tls_helpers" `Quick test_c4_ensure_rng_works_via_tls_helpers; + ]; + "H1_session_state", [ + Alcotest.test_case "double initialize rejected" `Quick test_h1_initialize_twice_fails; + Alcotest.test_case "ready requires Initializing" `Quick test_h1_ready_requires_initializing; + Alcotest.test_case "close from any state" `Quick test_h1_close_from_any_state; + ]; + "H2_version_negotiation", [ + Alcotest.test_case "unsupported version -> latest" `Quick test_h2_unsupported_version_returns_latest; + Alcotest.test_case "supported version negotiated" `Quick test_h2_supported_version_negotiated; + ]; + "L1_max_response_size", [ + Alcotest.test_case "default is 1MB" `Quick test_l1_default_max_response_size; + ]; + "M1_pre_init_validation", [ + Alcotest.test_case "pre-init accepts any" `Quick test_m1_pre_init_accepts_any_header; + Alcotest.test_case "post-init requires match" `Quick test_m1_post_init_requires_matching_header; + ]; "L3_scope_check", [ Alcotest.test_case "subset passes" `Quick test_l3_scope_check_subset; Alcotest.test_case "missing fails" `Quick test_l3_scope_check_missing;