diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim index a88a2d2e436fa..6bd9cc9892edf 100644 --- a/lib/pure/asynchttpserver.nim +++ b/lib/pure/asynchttpserver.nim @@ -146,15 +146,30 @@ proc respondError(req: Request, code: HttpCode): Future[void] = result = req.client.send(msg) proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] = + template invalidProtocol() = + raise newException(ValueError, "Invalid request protocol. Got: " & protocol) + result = default(tuple[orig: string, major, minor: int]) var i = protocol.skipIgnoreCase("HTTP/") if i != 5: - raise newException(ValueError, "Invalid request protocol. Got: " & - protocol) + invalidProtocol() result.orig = protocol - i.inc protocol.parseSaturatedNatural(result.major, i) + let majorLen = protocol.parseSaturatedNatural(result.major, i) + if majorLen == 0: + invalidProtocol() + i.inc majorLen + + if i >= protocol.len or protocol[i] != '.': + invalidProtocol() i.inc # Skip . - i.inc protocol.parseSaturatedNatural(result.minor, i) + + let minorLen = protocol.parseSaturatedNatural(result.minor, i) + if minorLen == 0: + invalidProtocol() + i.inc minorLen + + if i != protocol.len: + invalidProtocol() proc sendStatus(client: AsyncSocket, status: string): Future[void] = client.send("HTTP/1.1 " & status & "\c\L\c\L") diff --git a/tests/stdlib/tasynchttpserver.nim b/tests/stdlib/tasynchttpserver.nim index 5a7e2da4018a9..3df8b125868c4 100644 --- a/tests/stdlib/tasynchttpserver.nim +++ b/tests/stdlib/tasynchttpserver.nim @@ -9,7 +9,7 @@ import strutils from net import TimeoutError import std/assertions -import httpclient, asynchttpserver, asyncdispatch, asyncfutures +import httpclient, asynchttpserver, asyncdispatch, asyncfutures, asyncnet template runTest( handler: proc (request: Request): Future[void] {.gcsafe.}, @@ -113,9 +113,39 @@ proc testCustomContentLength() {.async.} = runTest(handler, request, test) +proc testMalformedProtocolsDoNotCrash() {.async.} = + for rawProtocol in [ + "HTTP/1", + "HTTP/.1", + "HTTP/1.", + "HTTP/1.1xyz", + ]: + let server = newAsyncHttpServer() + server.listen(Port(0)) + + var callbackCalled = false + proc handler(request: Request) {.async.} = + callbackCalled = true + await request.respond(Http200, "unexpected") + + asyncCheck server.acceptRequest(handler) + + let client = await asyncnet.dial("127.0.0.1", server.getPort) + await client.send("GET / " & rawProtocol & "\c\LHost: localhost\c\L\c\L") + + let lineFut = client.recvLine() + doAssert await withTimeout(lineFut, 1000), "Timed out waiting for response to " & rawProtocol + doAssert lineFut.read() == "HTTP/1.1 400 Bad Request" + + client.close() + await sleepAsync(100) + doAssert not callbackCalled + server.close() + waitFor(test200()) waitFor(test404()) waitFor(testCustomEmptyHeaders()) waitFor(testCustomContentLength()) +waitFor(testMalformedProtocolsDoNotCrash()) echo "OK"