diff --git a/.config/dotnet-tools.json b/.config/dotnet-tools.json index 13f3a42..8ba49f1 100644 --- a/.config/dotnet-tools.json +++ b/.config/dotnet-tools.json @@ -17,7 +17,7 @@ "rollForward": false }, "fantomas": { - "version": "7.0.0", + "version": "7.0.3", "commands": [ "fantomas" ], diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..90ec38b --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,63 @@ +## Build, Test & Lint Commands + +- **Build**: `dotnet fake build -t Build` (Release configuration) +- **Format Check**: `dotnet fake build -t CheckFormat` (validates Fantomas formatting) +- **Format**: `dotnet fake build -t Format` (applies Fantomas formatting) +- **All Tests**: `dotnet fake build -t RunTests` (builds + starts test server + runs all tests) +- **Unit Tests Only**: `dotnet build && dotnet tests/SwaggerProvider.Tests/bin/Release/net9.0/SwaggerProvider.Tests.dll` +- **Provider Tests (Integration)**: + 1. Build test server: `dotnet build tests/Swashbuckle.WebApi.Server/Swashbuckle.WebApi.Server.fsproj -c Release` + 2. Start server in background: `dotnet tests/Swashbuckle.WebApi.Server/bin/Release/net9.0/Swashbuckle.WebApi.Server.dll` + 3. Build tests: `dotnet build SwaggerProvider.TestsAndDocs.sln -c Release` + 4. Run tests: `dotnet tests/SwaggerProvider.ProviderTests/bin/Release/net9.0/SwaggerProvider.ProviderTests.dll` +- **Single Test**: Run via xunit runner: `dotnet [assembly] [filter]` + +## Code Style Guidelines + +**Language**: F# (net9.0 target framework) + +**Imports & Namespaces**: + +- `namespace [Module]` at file start; no `open` statements at module level +- Use `module [Name]` for nested modules +- Open dependencies after namespace declaration (e.g., `open Xunit`, `open FsUnitTyped`) +- Fully qualify internal modules: `SwaggerProvider.Internal.v2.Parser`, `SwaggerProvider.Internal.v3.Compilers` + +**Formatting** (via Fantomas, EditorConfig enforced): + +- 4-space indents, max 150 char line length +- `fsharp_max_function_binding_width=10`, `fsharp_max_infix_operator_expression=70` +- No space before parameter/lowercase invocation +- Multiline block brackets on same column, Stroustrup style enabled +- Bar before discriminated union declarations, max 3 blank lines + +**Naming Conventions**: + +- PascalCase for classes, types, modules, public members +- camelCase for local/private bindings, parameters +- Suffix test functions with `Tests` or use attributes like `[]`, `[]` + +**Type Annotations**: + +- Explicit return types for public functions (recommended) +- Use type inference for local bindings when obvious +- Generic type parameters: `'a`, `'b` (single quote prefix) + +**Error Handling**: + +- Use `Result<'T, 'Error>` or `Option<'T>` for fallible operations +- `failwith` or `failwithf` for errors in type providers and compilers +- Task-based async for I/O: `task { }` expressions in tests +- Match failures with `| _ -> ...` or pattern guards with `when` + +**File Organization**: + +- Tests use Xunit attributes: `[]`, `[]`, `[]` +- Design-time providers in `src/SwaggerProvider.DesignTime/`, runtime in `src/SwaggerProvider.Runtime/` +- Test schemas organized by OpenAPI version: `tests/.../Schemas/{v2,v3}/` + +## Key Patterns + +- Type Providers use `ProvidedApiClientBase` and compiler pipeline (DefinitionCompiler, OperationCompiler) +- SSRF protection enabled by default; disable with `SsrfProtection=false` static parameter +- Target net9.0; use implicit async/await (task expressions) diff --git a/src/SwaggerProvider.DesignTime/Provider.OpenApiClient.fs b/src/SwaggerProvider.DesignTime/Provider.OpenApiClient.fs index 8611c97..bf08144 100644 --- a/src/SwaggerProvider.DesignTime/Provider.OpenApiClient.fs +++ b/src/SwaggerProvider.DesignTime/Provider.OpenApiClient.fs @@ -8,7 +8,7 @@ open Swagger open SwaggerProvider.Internal open SwaggerProvider.Internal.v3.Compilers -module Cache = +module OpenApiCache = let providedTypes = Caching.createInMemoryCache(TimeSpan.FromSeconds 30.0) /// The Open API Provider. @@ -51,10 +51,7 @@ type public OpenApiClientTypeProvider(cfg: TypeProviderConfig) as this = t.DefineStaticParameters( staticParams, fun typeName args -> - let schemaPath = - let schemaPathRaw = unbox args.[0] - SchemaReader.getAbsolutePath cfg.ResolutionFolder schemaPathRaw - + let schemaPathRaw = unbox args.[0] let ignoreOperationId = unbox args.[1] let ignoreControllerPrefix = unbox args.[2] let preferNullable = unbox args.[3] @@ -62,14 +59,13 @@ type public OpenApiClientTypeProvider(cfg: TypeProviderConfig) as this = let ssrfProtection = unbox args.[5] let cacheKey = - (schemaPath, ignoreOperationId, ignoreControllerPrefix, preferNullable, preferAsync, ssrfProtection) + (schemaPathRaw, ignoreOperationId, ignoreControllerPrefix, preferNullable, preferAsync, ssrfProtection) |> sprintf "%A" - let addCache() = lazy let schemaData = - SchemaReader.readSchemaPath (not ssrfProtection) "" schemaPath + SchemaReader.readSchemaPath (not ssrfProtection) "" cfg.ResolutionFolder schemaPathRaw |> Async.RunSynchronously let openApiReader = Microsoft.OpenApi.Readers.OpenApiStringReader() @@ -96,18 +92,18 @@ type public OpenApiClientTypeProvider(cfg: TypeProviderConfig) as this = let ty = ProvidedTypeDefinition(tempAsm, ns, typeName, Some typeof, isErased = false, hideObjectMethods = true) - ty.AddXmlDoc("OpenAPI Provider for " + schemaPath) + ty.AddXmlDoc("OpenAPI Provider for " + schemaPathRaw) ty.AddMembers tys tempAsm.AddTypes [ ty ] ty try - Cache.providedTypes.GetOrAdd(cacheKey, addCache).Value + OpenApiCache.providedTypes.GetOrAdd(cacheKey, addCache).Value with _ -> - Cache.providedTypes.Remove(cacheKey) |> ignore + OpenApiCache.providedTypes.Remove(cacheKey) |> ignore - Cache.providedTypes.GetOrAdd(cacheKey, addCache).Value + OpenApiCache.providedTypes.GetOrAdd(cacheKey, addCache).Value ) t diff --git a/src/SwaggerProvider.DesignTime/Provider.SwaggerClient.fs b/src/SwaggerProvider.DesignTime/Provider.SwaggerClient.fs index 8e44ffd..0102e01 100644 --- a/src/SwaggerProvider.DesignTime/Provider.SwaggerClient.fs +++ b/src/SwaggerProvider.DesignTime/Provider.SwaggerClient.fs @@ -9,6 +9,9 @@ open SwaggerProvider.Internal open SwaggerProvider.Internal.v2.Parser open SwaggerProvider.Internal.v2.Compilers +module SwaggerCache = + let providedTypes = Caching.createInMemoryCache(TimeSpan.FromSeconds 30.0) + /// The Swagger Type Provider. [] type public SwaggerTypeProvider(cfg: TypeProviderConfig) as this = @@ -51,10 +54,7 @@ type public SwaggerTypeProvider(cfg: TypeProviderConfig) as this = t.DefineStaticParameters( staticParams, fun typeName args -> - let schemaPath = - let schemaPathRaw = unbox args.[0] - SchemaReader.getAbsolutePath cfg.ResolutionFolder schemaPathRaw - + let schemaPathRaw = unbox args.[0] let headersStr = unbox args.[1] let ignoreOperationId = unbox args.[2] let ignoreControllerPrefix = unbox args.[3] @@ -63,13 +63,13 @@ type public SwaggerTypeProvider(cfg: TypeProviderConfig) as this = let ssrfProtection = unbox args.[6] let cacheKey = - (schemaPath, headersStr, ignoreOperationId, ignoreControllerPrefix, preferNullable, preferAsync, ssrfProtection) + (schemaPathRaw, headersStr, ignoreOperationId, ignoreControllerPrefix, preferNullable, preferAsync, ssrfProtection) |> sprintf "%A" let addCache() = lazy let schemaData = - SchemaReader.readSchemaPath (not ssrfProtection) headersStr schemaPath + SchemaReader.readSchemaPath (not ssrfProtection) headersStr cfg.ResolutionFolder schemaPathRaw |> Async.RunSynchronously let schema = SwaggerParser.parseSchema schemaData @@ -87,13 +87,13 @@ type public SwaggerTypeProvider(cfg: TypeProviderConfig) as this = let ty = ProvidedTypeDefinition(tempAsm, ns, typeName, Some typeof, isErased = false, hideObjectMethods = true) - ty.AddXmlDoc("Swagger Provider for " + schemaPath) + ty.AddXmlDoc("Swagger Provider for " + schemaPathRaw) ty.AddMembers tys tempAsm.AddTypes [ ty ] ty - Cache.providedTypes.GetOrAdd(cacheKey, addCache).Value + SwaggerCache.providedTypes.GetOrAdd(cacheKey, addCache).Value ) t diff --git a/src/SwaggerProvider.DesignTime/Utils.fs b/src/SwaggerProvider.DesignTime/Utils.fs index ee31712..b375a58 100644 --- a/src/SwaggerProvider.DesignTime/Utils.fs +++ b/src/SwaggerProvider.DesignTime/Utils.fs @@ -7,6 +7,9 @@ module SchemaReader = open System.Net.Http let getAbsolutePath (resolutionFolder: string) (schemaPathRaw: string) = + if String.IsNullOrWhiteSpace(schemaPathRaw) then + invalidArg "schemaPathRaw" "The schema path cannot be null or empty." + let uri = Uri(schemaPathRaw, UriKind.RelativeOrAbsolute) if uri.IsAbsoluteUri then @@ -33,27 +36,60 @@ module SchemaReader = let isIp, ipAddr = IPAddress.TryParse host if isIp then - // Loopback - if IPAddress.IsLoopback ipAddr || ipAddr.ToString() = "0.0.0.0" then - failwithf "Cannot fetch schemas from localhost/loopback addresses: %s (set SsrfProtection=false for development)" host - // Private IPv4 ranges - let bytes = ipAddr.GetAddressBytes() - - let isPrivate = - ipAddr.AddressFamily = Sockets.AddressFamily.InterNetwork - && match bytes with - | [| 10uy; _; _; _ |] -> true // 10.0.0.0/8 - | [| 172uy; b1; _; _ |] when b1 >= 16uy && b1 <= 31uy -> true // 172.16.0.0/12 - | [| 192uy; 168uy; _; _ |] -> true // 192.168.0.0/16 - | [| 169uy; 254uy; _; _ |] -> true // Link-local 169.254.0.0/16 - | _ -> false - - if isPrivate then - failwithf "Cannot fetch schemas from private or link-local IP addresses: %s (set SsrfProtection=false for development)" host - else if - // Block localhost by name - host = "localhost" - then + // Check address family first to apply family-specific rules + match ipAddr.AddressFamily with + | Sockets.AddressFamily.InterNetwork -> + // IPv4 validation + let bytes = ipAddr.GetAddressBytes() + + // Check for IPv4 loopback or unspecified address + if IPAddress.IsLoopback ipAddr || ipAddr.ToString() = "0.0.0.0" then + failwithf "Cannot fetch schemas from localhost/loopback addresses: %s (set SsrfProtection=false for development)" host + + // Check for IPv4 private ranges + let isPrivateIPv4 = + match bytes with + // 10.0.0.0/8 + | [| 10uy; _; _; _ |] -> true + // 172.16.0.0/12 + | [| 172uy; secondByte; _; _ |] when secondByte >= 16uy && secondByte <= 31uy -> true + // 192.168.0.0/16 + | [| 192uy; 168uy; _; _ |] -> true + // Link-local 169.254.0.0/16 + | [| 169uy; 254uy; _; _ |] -> true + | _ -> false + + if isPrivateIPv4 then + failwithf "Cannot fetch schemas from private or link-local IP addresses: %s (set SsrfProtection=false for development)" host + + | Sockets.AddressFamily.InterNetworkV6 -> + // IPv6 validation + let bytes = ipAddr.GetAddressBytes() + + // Check for IPv6 private or reserved ranges + let isPrivateIPv6 = + match bytes with + // Loopback (::1) + | [| 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 1uy |] -> true + // Unspecified address (::) + | [| 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy; 0uy |] -> true + // Link-local (fe80::/10) - first byte 0xFE, second byte 0x80-0xBF + | [| 0xFEuy; secondByte; _; _; _; _; _; _; _; _; _; _; _; _; _; _ |] when secondByte >= 0x80uy && secondByte <= 0xBFuy -> true + // Unique Local Unicast (fc00::/7) - first byte 0xFC or 0xFD + | [| 0xFCuy; _; _; _; _; _; _; _; _; _; _; _; _; _; _; _ |] -> true + | [| 0xFDuy; _; _; _; _; _; _; _; _; _; _; _; _; _; _; _ |] -> true + // Multicast (ff00::/8) - first byte 0xFF + | [| 0xFFuy; _; _; _; _; _; _; _; _; _; _; _; _; _; _; _ |] -> true + | _ -> false + + if isPrivateIPv6 then + failwithf "Cannot fetch schemas from private or loopback IPv6 addresses: %s (set SsrfProtection=false for development)" host + + | _ -> + // Unsupported address family + failwithf "Cannot fetch schemas from unsupported IP address type: %s (set SsrfProtection=false for development)" host + // Block localhost by hostname + else if host = "localhost" then failwithf "Cannot fetch schemas from localhost/loopback addresses: %s (set SsrfProtection=false for development)" host let validateContentType (ignoreSsrfProtection: bool) (contentType: Headers.MediaTypeHeaderValue) = @@ -91,119 +127,158 @@ module SchemaReader = "Invalid Content-Type for schema: %s. Expected JSON or YAML content types only. This protects against SSRF attacks. Set SsrfProtection=false to disable this validation." mediaType - let readSchemaPath (ignoreSsrfProtection: bool) (headersStr: string) (schemaPathRaw: string) = + let readSchemaPath (ignoreSsrfProtection: bool) (headersStr: string) (resolutionFolder: string) (schemaPathRaw: string) = async { - let uri = Uri schemaPathRaw - - match uri.Scheme with - | "https" -> - // Validate URL to prevent SSRF (unless explicitly disabled) - validateSchemaUrl ignoreSsrfProtection uri - - let headers = - headersStr.Split '|' - |> Seq.choose(fun x -> - let pair = x.Split '=' - - if (pair.Length = 2) then Some(pair[0], pair[1]) else None) - - let request = new HttpRequestMessage(HttpMethod.Get, schemaPathRaw) - - for name, value in headers do - request.Headers.TryAddWithoutValidation(name, value) |> ignore - - // SECURITY: Remove UseDefaultCredentials to prevent credential leakage (always enforced) - use handler = new HttpClientHandler(UseDefaultCredentials = false) - use client = new HttpClient(handler, Timeout = System.TimeSpan.FromSeconds 60.0) - - let! res = - async { - let! response = client.SendAsync request |> Async.AwaitTask - - // Validate Content-Type to ensure we're parsing the correct format - validateContentType ignoreSsrfProtection response.Content.Headers.ContentType - - return! response.Content.ReadAsStringAsync() |> Async.AwaitTask - } - |> Async.Catch - - match res with - | Choice1Of2 x -> return x - | Choice2Of2(:? Swagger.OpenApiException as ex) when not <| isNull ex.Content -> - let content = - ex.Content.ReadAsStringAsync() - |> Async.AwaitTask - |> Async.RunSynchronously - - if String.IsNullOrEmpty content then - return ex.Reraise() + // Resolve the schema path to absolute path first + let resolvedPath = getAbsolutePath resolutionFolder schemaPathRaw + + // Check if this is a local file path (not a remote URL) + // First try to treat it as a local file path (absolute or relative) + let possibleFilePath = + try + if Path.IsPathRooted resolvedPath then + // Already an absolute path + if File.Exists resolvedPath then Some resolvedPath else None else - return content - | Choice2Of2(:? WebException as wex) when not <| isNull wex.Response -> - use stream = wex.Response.GetResponseStream() - use reader = new StreamReader(stream) - let err = reader.ReadToEnd() - - return - if String.IsNullOrEmpty err then - wex.Reraise() - else - err.ToString() - | Choice2Of2 e -> return failwith(e.ToString()) - | "http" -> - // HTTP is allowed only when SSRF protection is explicitly disabled (development/testing mode) - if not ignoreSsrfProtection then - return - failwithf - "HTTP URLs are not supported for security reasons. Use HTTPS or set SsrfProtection=false for development: %s" - schemaPathRaw + // Try to resolve relative paths (e.g., paths with ../ or from __SOURCE_DIRECTORY__) + let resolved = Path.GetFullPath resolvedPath + if File.Exists resolved then Some resolved else None + with _ -> + None + + match possibleFilePath with + | Some filePath -> + // Handle local file - read from disk + try + return File.ReadAllText filePath + with + | :? FileNotFoundException -> return failwithf "Schema file not found: %s" filePath + | ex -> return failwithf "Error reading schema file '%s': %s" filePath ex.Message + | None -> + // Handle as remote URL (HTTP/HTTPS) + let checkUri = Uri(resolvedPath, UriKind.RelativeOrAbsolute) + // Only treat truly local paths as local files (no scheme or relative paths) + // Reject file:// scheme as unsupported to prevent SSRF attacks + let isLocalFile = not checkUri.IsAbsoluteUri + + if isLocalFile then + // If we reach here with a local file that wasn't found, report the error + return failwithf "Schema file not found: %s" resolvedPath else - // Development mode: allow HTTP - validateSchemaUrl ignoreSsrfProtection uri - - let headers = - headersStr.Split '|' - |> Seq.choose(fun x -> - let pair = x.Split '=' - if (pair.Length = 2) then Some(pair[0], pair[1]) else None) - - let request = new HttpRequestMessage(HttpMethod.Get, schemaPathRaw) - - for name, value in headers do - request.Headers.TryAddWithoutValidation(name, value) |> ignore - - use handler = new HttpClientHandler(UseDefaultCredentials = false) - use client = new HttpClient(handler, Timeout = System.TimeSpan.FromSeconds 60.0) - - let! res = - async { - let! response = client.SendAsync(request) |> Async.AwaitTask - - // Validate Content-Type to ensure we're parsing the correct format - validateContentType ignoreSsrfProtection response.Content.Headers.ContentType - - return! response.Content.ReadAsStringAsync() |> Async.AwaitTask - } - |> Async.Catch - - match res with - | Choice1Of2 x -> return x - | Choice2Of2(:? WebException as wex) when not <| isNull wex.Response -> - use stream = wex.Response.GetResponseStream() - use reader = new StreamReader(stream) - let err = reader.ReadToEnd() - - return - if String.IsNullOrEmpty err then - wex.Reraise() + // Handle remote URL (HTTP/HTTPS) + let uri = Uri resolvedPath + + match uri.Scheme with + | "https" -> + // Validate URL to prevent SSRF (unless explicitly disabled) + validateSchemaUrl ignoreSsrfProtection uri + + let headers = + headersStr.Split '|' + |> Seq.choose(fun x -> + let pair = x.Split '=' + if (pair.Length = 2) then Some(pair[0], pair[1]) else None) + + let request = new HttpRequestMessage(HttpMethod.Get, resolvedPath) + + for name, value in headers do + request.Headers.TryAddWithoutValidation(name, value) |> ignore + + // SECURITY: Remove UseDefaultCredentials to prevent credential leakage (always enforced) + use handler = new HttpClientHandler(UseDefaultCredentials = false) + use client = new HttpClient(handler, Timeout = TimeSpan.FromSeconds 60.0) + + let! res = + async { + let! response = client.SendAsync request |> Async.AwaitTask + + // Validate Content-Type to ensure we're parsing the correct format + validateContentType ignoreSsrfProtection response.Content.Headers.ContentType + + return! response.Content.ReadAsStringAsync() |> Async.AwaitTask + } + |> Async.Catch + + match res with + | Choice1Of2 x -> return x + | Choice2Of2(:? Swagger.OpenApiException as ex) when not <| isNull ex.Content -> + let content = + ex.Content.ReadAsStringAsync() + |> Async.AwaitTask + |> Async.RunSynchronously + + if String.IsNullOrEmpty content then + return ex.Reraise() else - err.ToString() - | Choice2Of2 e -> return failwith(e.ToString()) - | _ -> - let request = WebRequest.Create(schemaPathRaw) - use! response = request.GetResponseAsync() |> Async.AwaitTask - use sr = new StreamReader(response.GetResponseStream()) - return! sr.ReadToEndAsync() |> Async.AwaitTask + return content + | Choice2Of2(:? WebException as wex) when not <| isNull wex.Response -> + use stream = wex.Response.GetResponseStream() + use reader = new StreamReader(stream) + let err = reader.ReadToEnd() + + return + if String.IsNullOrEmpty err then + wex.Reraise() + else + err.ToString() + | Choice2Of2 e -> return failwith(e.ToString()) + + | "http" -> + // HTTP is allowed only when SSRF protection is explicitly disabled (development/testing mode) + if not ignoreSsrfProtection then + return + failwithf + "HTTP URLs are not supported for security reasons. Use HTTPS or set SsrfProtection=false for development: %s" + resolvedPath + else + // Development mode: allow HTTP + validateSchemaUrl ignoreSsrfProtection uri + + let headers = + headersStr.Split '|' + |> Seq.choose(fun x -> + let pair = x.Split '=' + if (pair.Length = 2) then Some(pair[0], pair[1]) else None) + + let request = new HttpRequestMessage(HttpMethod.Get, resolvedPath) + + for name, value in headers do + request.Headers.TryAddWithoutValidation(name, value) |> ignore + + use handler = new HttpClientHandler(UseDefaultCredentials = false) + use client = new HttpClient(handler, Timeout = TimeSpan.FromSeconds 60.0) + + let! res = + async { + let! response = client.SendAsync(request) |> Async.AwaitTask + + // Validate Content-Type to ensure we're parsing the correct format + validateContentType ignoreSsrfProtection response.Content.Headers.ContentType + + return! response.Content.ReadAsStringAsync() |> Async.AwaitTask + } + |> Async.Catch + + match res with + | Choice1Of2 x -> return x + | Choice2Of2(:? WebException as wex) when not <| isNull wex.Response -> + use stream = wex.Response.GetResponseStream() + use reader = new StreamReader(stream) + let err = reader.ReadToEnd() + + return + if String.IsNullOrEmpty err then + wex.Reraise() + else + err.ToString() + | Choice2Of2 e -> return failwith(e.ToString()) + + | _ -> + // SECURITY: Reject unknown URL schemes to prevent SSRF attacks via file://, ftp://, etc. + return + failwithf + "Unsupported URL scheme in schema path: '%s'. Only HTTPS is supported for remote schemas (HTTP requires SsrfProtection=false). For local files, ensure the path is absolute or relative to the resolution folder." + resolvedPath } type UniqueNameGenerator() = diff --git a/tests/SwaggerProvider.Tests/SsrfSecurityTests.fs b/tests/SwaggerProvider.Tests/SsrfSecurityTests.fs new file mode 100644 index 0000000..34be244 --- /dev/null +++ b/tests/SwaggerProvider.Tests/SsrfSecurityTests.fs @@ -0,0 +1,352 @@ +namespace SwaggerProvider.Tests.SsrfSecurityTests + +open System +open Xunit +open SwaggerProvider.Internal.SchemaReader + +/// Tests for SSRF protection - Critical: Unknown URL schemes +/// These tests verify that only safe URL schemes are allowed +module UnknownSchemeTests = + + [] + let ``Reject file protocol to prevent local file access``() = + task { + // Test: file:// protocol should be rejected to prevent SSRF via local file access + let fileUrl = "file:///etc/passwd" + + let! ex = + Assert.ThrowsAsync(fun () -> + task { + let! _ = readSchemaPath false "" "" fileUrl + return () + }) + + + Assert.Contains("Unsupported URL scheme", ex.Message) + Assert.Contains("file://", ex.Message) + } + + [] + let ``Reject FTP protocol to prevent remote protocol access``() = + task { + // Test: ftp:// protocol should be rejected to prevent SSRF via FTP + let ftp_url = "ftp://internal-server/schema.json" + + let! ex = + Assert.ThrowsAsync(fun () -> + task { + let! _ = readSchemaPath false "" "" ftp_url + return () + }) + + Assert.Contains("Unsupported URL scheme", ex.Message) + } + + [] + let ``Reject Gopher protocol to prevent remote protocol access``() = + task { + // Test: gopher:// protocol should be rejected to prevent SSRF via Gopher + let gopher_url = "gopher://internal-server/schema.json" + + let! ex = + Assert.ThrowsAsync(fun () -> + task { + let! _ = readSchemaPath false "" "" gopher_url + return () + }) + + Assert.Contains("Unsupported URL scheme", ex.Message) + } + + [] + let ``Reject DICT protocol to prevent remote protocol access``() = + task { + // Test: dict:// protocol should be rejected to prevent SSRF via DICT + let dict_url = "dict://internal-server/schema.json" + + let! ex = + Assert.ThrowsAsync(fun () -> + task { + let! _ = readSchemaPath false "" "" dict_url + return () + }) + + Assert.Contains("Unsupported URL scheme", ex.Message) + } + + +/// Tests for SSRF protection - High: IPv6 private ranges +/// These tests verify that IPv6 loopback, link-local, ULA, multicast addresses are rejected +module IPv6SecurityTests = + + [] + let ``Reject IPv6 loopback address ::1``() = + // Test: IPv6 loopback ::1 should be rejected to prevent access to localhost services + let ipv6_loopback_uri = Uri("https://[::1]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_loopback_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + Assert.Contains("::1", thrown_exception.Message) + + [] + let ``Reject IPv6 link-local address fe80::1``() = + // Test: IPv6 link-local fe80::1 should be rejected to prevent access to link-local services + let ipv6_link_local_uri = Uri("https://[fe80::1]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_link_local_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + + [] + let ``Reject IPv6 unique local address fd00::1``() = + // Test: IPv6 ULA fd00::1 should be rejected to prevent access to private network ranges + let ipv6_ula_uri = Uri("https://[fd00::1]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_ula_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + + [] + let ``Reject IPv6 unique local address fc00::1``() = + // Test: IPv6 ULA fc00::1 should be rejected to prevent access to private network ranges + let ipv6_ula_fc_uri = Uri("https://[fc00::1]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_ula_fc_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + + [] + let ``Reject IPv6 unspecified address ::``() = + // Test: IPv6 unspecified address :: should be rejected to prevent access to localhost services + let ipv6_unspecified_uri = Uri("https://[::]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_unspecified_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + + [] + let ``Reject IPv6 multicast address ff02::1``() = + // Test: IPv6 multicast ff02::1 should be rejected to prevent access to multicast addresses + let ipv6_multicast_uri = Uri("https://[ff02::1]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_multicast_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + + [] + let ``Reject IPv6 multicast address ff00::1``() = + // Test: IPv6 multicast ff00::1 should be rejected to prevent access to multicast addresses + let ipv6_multicast_ff00_uri = Uri("https://[ff00::1]/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false ipv6_multicast_ff00_uri) + + Assert.Contains("private or loopback IPv6 addresses", thrown_exception.Message) + + [] + let ``Allow public IPv6 documentation address 2001:db8::1``() = + // Test: Public IPv6 documentation range 2001:db8::1 should pass SSRF validation + // (Note: May fail due to network access, but SSRF validation should pass) + let public_ipv6_uri = Uri("https://[2001:db8::1]/schema.json") + + try + validateSchemaUrl false public_ipv6_uri + with + | ex when ex.Message.Contains("private or loopback") -> + // SSRF validation failed incorrectly + Assert.True(false, $"Public IPv6 should not be blocked by SSRF validation: {ex.Message}") + | _ -> + // Other errors are also acceptable (network, etc.) + () + + + +/// Tests for IPv4 private ranges +/// These tests verify that IPv4 loopback and private ranges are rejected +module IPv4PrivateRangeTests = + + [] + let ``Reject IPv4 loopback address 127.0.0.1``() = + // Test: IPv4 loopback 127.0.0.1 should be rejected to prevent access to localhost services + let loopback_uri = Uri("https://127.0.0.1/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false loopback_uri) + + Assert.Contains("localhost/loopback", thrown_exception.Message) + + [] + let ``Reject IPv4 private range 10.0.0.0/8``() = + // Test: IPv4 private range 10.0.0.1 should be rejected to prevent access to private networks + let private_10_uri = Uri("https://10.0.0.1/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false private_10_uri) + + Assert.Contains("private or link-local", thrown_exception.Message) + + [] + let ``Reject IPv4 private range 172.16.0.0/12``() = + // Test: IPv4 private range 172.16.0.1 should be rejected to prevent access to private networks + let private_172_uri = Uri("https://172.16.0.1/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false private_172_uri) + + Assert.Contains("private or link-local", thrown_exception.Message) + + [] + let ``Reject IPv4 private range 172.31.255.255``() = + // Test: IPv4 private range upper bound 172.31.255.255 should be rejected + let private_172_upper_uri = Uri("https://172.31.255.255/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false private_172_upper_uri) + + Assert.Contains("private or link-local", thrown_exception.Message) + + [] + let ``Reject IPv4 private range 192.168.0.0/16``() = + // Test: IPv4 private range 192.168.1.1 should be rejected to prevent access to private networks + let private_192_uri = Uri("https://192.168.1.1/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false private_192_uri) + + Assert.Contains("private or link-local", thrown_exception.Message) + + [] + let ``Reject IPv4 link-local address 169.254.0.0/16``() = + // Test: IPv4 link-local 169.254.0.1 should be rejected to prevent access to link-local services + let link_local_uri = Uri("https://169.254.0.1/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false link_local_uri) + + Assert.Contains("private or link-local", thrown_exception.Message) + + +/// Tests for hostname validation +/// These tests verify that localhost hostname and public hostnames are handled correctly +module HostnameValidationTests = + + [] + let ``Reject localhost hostname``() = + // Test: localhost hostname should be rejected to prevent access to localhost services + let localhost_uri = Uri("https://localhost/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false localhost_uri) + + Assert.Contains("localhost/loopback", thrown_exception.Message) + + [] + let ``Allow valid public hostname api.example.com``() = + // Test: Valid public hostname should pass SSRF validation + // (Note: May fail due to network access, but SSRF validation should pass) + let public_uri = Uri("https://api.example.com/schema.json") + + try + validateSchemaUrl false public_uri + with + | ex when ex.Message.Contains("localhost") || ex.Message.Contains("private") -> + Assert.Fail($"Public hostname should not be blocked by SSRF validation: {ex.Message}") + | _ -> () + + +/// Tests for relative file paths +/// These tests verify that relative file paths with __SOURCE_DIRECTORY__ work correctly +module RelativeFilePathTests = + + [] + let ``Allow relative file paths with __SOURCE_DIRECTORY__``() = + task { + // Test: Relative file paths using __SOURCE_DIRECTORY__ should work correctly + // This ensures that development-time file references like: + // let Schema = __SOURCE_DIRECTORY__ + "/../Schemas/v2/petstore.json" + // are properly handled (not rejected by SSRF validation) + let schemaPath = __SOURCE_DIRECTORY__ + "/../Schemas/v2/petstore.json" + + try + let! _ = readSchemaPath false "" "" schemaPath + () // If file exists, that's fine + with + | :? Swagger.OpenApiException -> + // Swagger parsing errors are okay - means file was read + () + | ex when ex.Message.Contains("Schema file not found") -> + // File not found is okay - path was resolved correctly + () + | ex when + ex.Message.Contains("Unsupported URL scheme") + || ex.Message.Contains("localhost") + || ex.Message.Contains("private") + -> + // SSRF validation errors mean relative paths are being blocked - this is the bug we're checking for + Assert.Fail($"Relative file paths should not be rejected by SSRF validation: {ex.Message}") + | _ -> + // Other errors (file reading issues, etc.) are acceptable + () + } + + +/// Tests for disabled SSRF protection (development mode) +/// These tests verify that when SSRF protection is disabled, all addresses are allowed +module SsrfBypassTests = + + [] + let ``Allow IPv4 loopback when ignoreSsrfProtection is true``() = + // Test: IPv4 loopback should be allowed when SSRF protection is disabled + let loopback_uri = Uri("https://127.0.0.1/schema.json") + // Should not throw when ignoreSsrfProtection=true + validateSchemaUrl true loopback_uri + + [] + let ``Allow IPv6 loopback when ignoreSsrfProtection is true``() = + // Test: IPv6 loopback should be allowed when SSRF protection is disabled + let ipv6_loopback_uri = Uri("https://[::1]/schema.json") + // Should not throw when ignoreSsrfProtection=true + validateSchemaUrl true ipv6_loopback_uri + + [] + let ``Allow IPv4 private range when ignoreSsrfProtection is true``() = + // Test: IPv4 private range should be allowed when SSRF protection is disabled + let private_uri = Uri("https://192.168.1.1/schema.json") + // Should not throw when ignoreSsrfProtection=true + validateSchemaUrl true private_uri + + [] + let ``Allow IPv6 private range when ignoreSsrfProtection is true``() = + // Test: IPv6 private range should be allowed when SSRF protection is disabled + let ipv6_private_uri = Uri("https://[fd00::1]/schema.json") + // Should not throw when ignoreSsrfProtection=true + validateSchemaUrl true ipv6_private_uri + + [] + let ``Reject HTTP in production mode``() = + // Test: HTTP should be rejected in production mode (HTTPS only) + let http_url = Uri("http://api.example.com/schema.json") + + let thrown_exception = + Assert.Throws(fun () -> validateSchemaUrl false http_url) + + Assert.Contains("Only HTTPS URLs are allowed", thrown_exception.Message) + + [] + let ``Allow HTTP when ignoreSsrfProtection is true``() = + // Test: HTTP should be allowed when SSRF protection is disabled (development mode) + let http_url = Uri("http://localhost/schema.json") + + try + validateSchemaUrl true http_url + with + | ex when ex.Message.Contains("Only HTTPS") -> + Assert.True(false, $"HTTP should not be rejected by SSRF validation when disabled: {ex.Message}") + | _ -> () diff --git a/tests/SwaggerProvider.Tests/SwaggerProvider.Tests.fsproj b/tests/SwaggerProvider.Tests/SwaggerProvider.Tests.fsproj index 3292314..5b94161 100644 --- a/tests/SwaggerProvider.Tests/SwaggerProvider.Tests.fsproj +++ b/tests/SwaggerProvider.Tests/SwaggerProvider.Tests.fsproj @@ -16,6 +16,7 @@ +