diff --git a/src/workerd/server/server-test.c++ b/src/workerd/server/server-test.c++ index 81c87a4e126..621a81547b8 100644 --- a/src/workerd/server/server-test.c++ +++ b/src/workerd/server/server-test.c++ @@ -131,6 +131,15 @@ class TestStream { KJ_EXPECT(gotCode == expectedCode); } + void recvBytes(int expectedBytes) { + auto actual = readAllBytes(); + if (actual == nullptr) { + KJ_FAIL_REQUIRE("message never received"); + } else { + KJ_EXPECT(actual.size() == expectedBytes); + } + } + void sendHttpGet(kj::StringPtr path, kj::SourceLocation loc = {}) { send(kj::str("GET ", path, " HTTP/1.1\n" @@ -213,6 +222,36 @@ class TestStream { // isEof() may prematurely read a character. Keep it off to the side for the next actual read. kj::Maybe premature; + // return raw bytes from stream + kj::Vector readAllBytes() { + kj::Vector buffer(256); + KJ_IF_SOME(p, premature) { + buffer.add(p); + } + + // Continuously try to read until there's nothing to read (or we've gone way past the size + // expected). + for (;;) { + size_t pos = buffer.size(); + buffer.resize(kj::max(buffer.size() + 256, buffer.capacity())); + + auto promise = stream->tryRead(buffer.begin() + pos, 1, buffer.size() - pos); + if (!promise.poll(ws)) { + // A tryRead() of 1 byte didn't resolve, there must be no data to read. + buffer.resize(pos); + break; + } + size_t n = promise.wait(ws); + if (n == 0) { + buffer.resize(pos); + break; + } + buffer.resize(pos); + }; + buffer.add('\0'); + return buffer; + } + kj::String readAllAvailable() { kj::Vector buffer(256); KJ_IF_SOME(p, premature) { @@ -1472,6 +1511,155 @@ KJ_TEST("Server: capability bindings") { )"_blockquote); } +KJ_TEST("Server: Hyperdrive TLS binding passes sslmode prefer") { + TestServer test(R"(( + services = [ + ( name = "hello", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` let items = []; + ` const connection = await env.hyperdrive.connect(); + ` const encoded = new TextEncoder().encode("hyperdrive-test"); + ` await connection.writable.getWriter().write(new Uint8Array(encoded)); + ` items.push(`Hello from Hyperdrive(${env.hyperdrive.user}) connected TLS stream\n`); + ` return new Response(items.join("")); + ` } + `} + ) + ], + bindings = [ + ( name = "hyperdrive", + hyperdrive = ( + designator = "hyperdrive-outbound", + database = "test-db", + user = "test-user", + password = "test-password", + scheme = "postgresql" + ) + ) + ] + ) + ), + ( name = "hyperdrive-outbound", database = ( + address = "hyperdrive-host", + scheme = "postgresql", + sslmode = "prefer", + tcp = ( + tlsOptions = ( + trustBrowserCas = true + ) + ) + )) + ], + sockets = [ + ( name = "main", + address = "test-addr", + service = "hello" + ) + ] + ))"_kj); + + test.start(); + auto conn = test.connect("test-addr"); + conn.sendHttpGet("/"); + { + auto subreq = test.receiveSubrequest("hyperdrive-host"); + // Avoids deadlock of waiting to read initial message from client + subreq.recvBytes(1); + // For tests respond with "N" for not supported and fall back to plain tcp + subreq.send("N"); + // After connection is made receive message from client + subreq.recv("hyperdrive-test"); + } + conn.recvHttp200(R"( + Hello from Hyperdrive(test-user) connected TLS stream + )"_blockquote); +} + +KJ_TEST("Server: Hyperdrive TLS binding fails sslmode require") { + TestServer test(R"(( + services = [ + ( name = "hello-require", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` try { + ` const connection = await env.hyperdrive.connect(); + ` const encoded = new TextEncoder().encode("hyperdrive-test"); + ` await connection.writable.getWriter().write(new Uint8Array(encoded)); + ` return new Response("Should not reach here"); + ` } catch (e) { + ` return new Response("Connection failed: " + e.message, { status: 500 }); + ` } + ` } + `} + ) + ], + bindings = [ + ( name = "hyperdrive", + hyperdrive = ( + designator = "hyperdrive-outbound", + database = "test-db", + user = "test-user", + password = "test-password", + scheme = "postgresql" + ) + ) + ] + ) + ), + ( name = "hyperdrive-outbound", database = ( + address = "hyperdrive-host", + scheme = "postgresql", + sslmode = "require", + tcp = ( + tlsOptions = ( + trustBrowserCas = true + ) + ) + )) + ], + sockets = [ + ( name = "main", + address = "test-addr", + service = "hello-require" + ) + ] + ))"_kj); + + KJ_EXPECT_LOG(ERROR, "Server does not support SSL, but client requires it"); + KJ_EXPECT_LOG(WARNING, "failed to connect to local database"); + KJ_EXPECT_LOG(ERROR, "unexpected error connecting to database"); + KJ_EXPECT_LOG(ERROR, "unexpected error connecting to database"); + + test.start(); + auto conn = test.connect("test-addr"); + conn.sendHttpGet("/"); + { + auto subreq = test.receiveSubrequest("hyperdrive-host"); + // Avoids deadlock of waiting to read initial message from client + subreq.recvBytes(1); + // For tests respond with "N" for not supported, which should fail since sslmode = 'require' + subreq.send("N"); + } + + conn.recvRegex(R"( + HTTP/1.1 500 Internal Server Error + Content-Length: 71 + Content-Type: text/plain;charset=UTF-8 + + Connection failed: internal error.*)"_blockquote); +} + KJ_TEST("Server: cyclic bindings") { TestServer test(R"(( services = [ diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index ac222f197f0..ac9ad56dcd4 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -1324,6 +1324,334 @@ kj::Own Server::makeDiskDirectoryService(kj::StringPtr name, } } +class Server::DatabaseService final: public Service, private WorkerInterface { + public: + DatabaseService(kj::Own addrParam) + : addr(kj::mv(addrParam)) {} + + DatabaseService(kj::Own addrParam, + kj::Own tlsContextParam, + kj::String hostnameParam, + kj::String schemeParam, + kj::String sslmodeParam) + : addr(kj::mv(addrParam)), tlsContext(kj::mv(tlsContextParam)), realHostname(kj::mv(hostnameParam)), scheme(kj::mv(schemeParam)), sslmode(kj::mv(sslmodeParam)) { + } + + kj::Own startRequest(IoChannelFactory::SubrequestMetadata metadata) override { + return {this, kj::NullDisposer::instance}; + } + + bool hasHandler(kj::StringPtr handlerName) override { + return handlerName == "fetch"_kj || handlerName == "connect"_kj; + } + + private: + kj::Own addr; + kj::Own tlsContext; + kj::String realHostname; + kj::String scheme; + kj::String sslmode; + + kj::Promise request(kj::HttpMethod method, + kj::StringPtr url, + const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, + kj::HttpService::Response& response) override { + throwUnsupported(); + } + + kj::Promise connect(kj::StringPtr host, + const kj::HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& tunnel, + kj::HttpConnectSettings settings) override { + TRACE_EVENT("workerd", "DatabaseService::connect()", "host", host.cStr()); + auto io_stream = co_await addr->connect(); + bool clientTunnelAccepted = false; + bool sslmodePrefer = sslmode == nullptr || sslmode == "" || sslmode == "prefer"; + bool sslmodeRequire = sslmode != nullptr && sslmode == "require"; + bool retryConnection = false; + if (sslmodePrefer || sslmodeRequire) { + try { + // Connect with plain TCP + if(scheme == "postgres" || scheme == "postgresql") { + // Send sslRequest packet + kj::byte sslRequest[8] = { 0, 0, 0, 8, 0x04, 0xd2, 0x16, 0x2f }; + co_await io_stream->write(kj::arrayPtr(sslRequest, sizeof(sslRequest))); + + // Read 1 byte response ('S' = SSL supported, 'N' = not supported) + kj::byte response[1]; + co_await io_stream->read(response, 1); + if((char)response[0] == 'S') { + // Upgrade to TLS stream + io_stream = co_await tlsContext->wrapClient( + kj::mv(io_stream), + realHostname + ); + } else { + if(sslmodeRequire) { + // Server doesn't support SSL but we require it + auto errorMessage = kj::str("Server does not support SSL, but client requires it"); + KJ_LOG(ERROR, errorMessage); + if (!clientTunnelAccepted) { + kj::HttpHeaderTable headerTable; + kj::HttpHeaders headers(headerTable); + auto body = tunnel.reject(501, "Not Implemented"_kj, headers, errorMessage.size()); + co_await body->write(errorMessage.asBytes()); + } + co_return; + } + // Fall back to plain tcp connection without tls + } + } else if (scheme == "mysql") { + // Read MySQL initial handshake packet from server + kj::byte initPacketHeader[4]; + co_await io_stream->read(kj::arrayPtr(initPacketHeader, 4), 4); + // Parse little-endian of first 3 header bytes to get payload length + uint32_t payloadLength = static_cast(initPacketHeader[0]) | + (static_cast(initPacketHeader[1]) << 8) | + (static_cast(initPacketHeader[2]) << 16); + auto initPacketPayload = kj::heapArray(payloadLength); + co_await io_stream->read(initPacketPayload, payloadLength); + + // Check if server supports SSL + if(serverSupportsSSL(initPacketPayload.asPtr(), payloadLength)) { + // Send our own fixed SSL request to server + kj::byte ssl_request_packet[36] = { + 0x20, 0x00, 0x00, 0x01, // payload length = 32, sequence id = 1 + 0x00, 0x0A, 0x00, 0x00, // capability_flags = CLIENT_SSL | CLIENT_PROTOCOL_41 + 0x00, 0x00, 0x00, 0x00, // max_packet_size = 0 + 0x21, // character_set = utf8_general_ci + 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00 + }; + co_await io_stream->write(kj::arrayPtr(ssl_request_packet, sizeof(ssl_request_packet))); + // Upgrade to TLS + io_stream = co_await tlsContext->wrapClient(kj::mv(io_stream), realHostname); + // Accept client connect and send handshake back to client + tunnel.accept(200, "OK", kj::HttpHeaders(kj::HttpHeaderTable{})); + clientTunnelAccepted = true; + + // After successfully upgrading the stream to TLS we need to handle the auth + // packets sent from the client. + // + // Client will send auth packets with wrong sequence ID and we need to manually fix them. + // MySQL db will reject the connection if the sequence ID is out of order + // and the mysql driver will complain about packets out of order as well. + co_await connection.write(kj::arrayPtr(initPacketHeader, 4)); + co_await connection.write(initPacketPayload); + + kj::byte clientHeader[4]; + co_await connection.read(kj::arrayPtr(clientHeader, 4), 4); + + // Parse little-endian of first 3 header bytes to get payload length + uint32_t clientPayloadLen = static_cast(clientHeader[0]) | + (static_cast(clientHeader[1]) << 8) | + (static_cast(clientHeader[2]) << 16); + + auto clientPayload = kj::heapArray(clientPayloadLen); + co_await connection.read(clientPayload, clientPayloadLen); + + // Increment sequence ID and write the client auth payload to mysql server + clientHeader[3]++; + co_await io_stream->write(kj::arrayPtr(clientHeader, 4)); + co_await io_stream->write(clientPayload); + + // Read and forward server's response, fixing sequence ID + kj::byte serverHeader[4]; + co_await io_stream->read(kj::arrayPtr(serverHeader, 4), 4); + + uint32_t serverPayloadLen = static_cast(serverHeader[0]) | + (static_cast(serverHeader[1]) << 8) | + (static_cast(serverHeader[2]) << 16); + + auto serverPayload = kj::heapArray(serverPayloadLen); + co_await io_stream->read(serverPayload, serverPayloadLen); + + // Increment sequence ID and write the server auth response to client + serverHeader[3]--; + co_await connection.write(kj::arrayPtr(serverHeader, 4)); + co_await connection.write(serverPayload); + } else { + if(sslmodeRequire) { + // Server doesn't support SSL but we require it + auto errorMessage = kj::str("Server does not support SSL, but client requires it"); + KJ_LOG(ERROR, errorMessage); + if (!clientTunnelAccepted) { + kj::HttpHeaderTable headerTable; + kj::HttpHeaders headers(headerTable); + auto body = tunnel.reject(501, "Not Implemented"_kj, headers, errorMessage.size()); + co_await body->write(errorMessage.asBytes()); + } + co_return; + } else { + retryConnection = true; + } + + // fall back to plain tcp + } + } + } catch (kj::Exception e) { + // Fail the client connection when sslmode is 'require' + if(sslmodeRequire) { + throw e; + } else { + retryConnection = true; + } + } + } + auto promises = kj::heapArrayBuilder>(2); + if (retryConnection) { + // In the case where we attempted tls upgrade connection failed, we will reconnect with plain tcp + io_stream = co_await addr->connect(); + } + promises.add(connection.pumpTo(*io_stream).then([&io_stream = *io_stream](uint64_t size) { + io_stream.shutdownWrite(); + })); + + promises.add(io_stream->pumpTo(connection).then([&connection](uint64_t size) { + connection.shutdownWrite(); + })); + + if(!clientTunnelAccepted) { + tunnel.accept(200, "OK", kj::HttpHeaders(kj::HttpHeaderTable{})); + } + + co_await kj::joinPromisesFailFast(promises.finish()).attach(kj::mv(io_stream)); + } + + // MySQL helper that skips through reading the init packet + // until we get to the ssl capability flag, only need to look at lower 16 bits + // docs: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html + bool serverSupportsSSL(kj::ArrayPtr buf, size_t len) { + // Skip protocol_version + size_t offset = 1; + // Find end of server_version string (null terminator) + while (offset < len && buf[offset] != 0x00) + offset++; + // Skip null terminator + offset++; + + // Ensure there are enough bytes left for fixed fields + if (offset + 4 + 8 + 1 + 2 > len) + return false; + + // Skip connection_id, auth_plugin_data_part_1, filler + offset += 4; + offset += 8; + offset += 1; + + // Read 2-byte little-endian capability_flags_lower + uint16_t caps = buf[offset] | (buf[offset + 1] << 8); + + // Check SSL support + return (caps & 2048) != 0; + } + + kj::Promise prewarm(kj::StringPtr url) override { + return kj::READY_NOW; + } + kj::Promise runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override { + throwUnsupported(); + } + kj::Promise runAlarm(kj::Date scheduledTime, uint32_t retryCount) override { + throwUnsupported(); + } + kj::Promise customEvent(kj::Own event) override { + return event->notSupported(); + } + + [[noreturn]] void throwUnsupported() { + JSG_FAIL_REQUIRE(Error, "Database servers don't support this event type."); + } +}; + +kj::Own Server::makeDatabaseService(kj::StringPtr name, + config::DatabaseServer::Reader conf, + kj::HttpHeaderTable::Builder& headerTableBuilder) { + TRACE_EVENT("workerd", "Server::makeDatabaseService()", "name", name.cStr()); + kj::StringPtr addrStr = nullptr; + kj::String ownAddrStr = nullptr; + + KJ_IF_SOME(override, externalOverrides.findEntry(name)) { + addrStr = ownAddrStr = kj::mv(override.value); + externalOverrides.erase(override); + } else if (conf.hasAddress()) { + addrStr = conf.getAddress(); + } else { + reportConfigError(kj::str("Database service \"", name, + "\" has no address in the config, so must be specified " + "on the command line with `--external-addr`.")); + return makeInvalidConfigService(); + } + + // Extract real hostname (without port) and port from address + kj::String tlsHostname; + kj::Maybe extractedPort; + auto colonPos = addrStr.findLast(':'); + KJ_IF_SOME(pos, colonPos) { + tlsHostname = kj::str(addrStr.slice(0, pos)); + // Extract port (everything after colon, up to any '/') + auto afterColon = addrStr.slice(pos + 1); + KJ_IF_SOME(slashPos, afterColon.findFirst('/')) { + auto portStr = kj::str(afterColon.slice(0, slashPos)); + extractedPort = portStr.parseAs(); + } else { + extractedPort = afterColon.parseAs(); + } + } else { + tlsHostname = kj::str(addrStr); + } + + // Set scheme and database port + kj::StringPtr schemePtr = nullptr; + if (conf.hasScheme()) { + schemePtr = conf.getScheme(); + } + + int databasePort; + KJ_IF_SOME(port, extractedPort) { + // Port was explicitly provided in address, use it + databasePort = port; + } else { + // No port in address, use default based on scheme + if (schemePtr == nullptr || schemePtr == "postgres"_kj || schemePtr == "postgresql"_kj) { + databasePort = 5432; + } else if (schemePtr == "mysql"_kj) { + databasePort = 3306; + } else { + reportConfigError(kj::str("Database service \"", name, + "\" has invalid scheme in the config")); + return makeInvalidConfigService(); + } + } + + // Set up TLS configurations if provided + auto tcpConf = conf.getTcp(); + + // Force a worker to connect over TLS by default + kj::StringPtr sslmodePtr = conf.getSslmode(); + auto addr = kj::heap(network.parseAddress(addrStr, databasePort)); + if (tcpConf.hasTlsOptions() || (sslmodePtr != nullptr && sslmodePtr != "")) { + kj::Maybe certificateHost; + if (tcpConf.hasCertificateHost()) { + certificateHost = tcpConf.getCertificateHost(); + } + auto tlsContext = makeTlsContext(tcpConf.getTlsOptions()); + addr = kj::heap( + makeTlsNetworkAddress(tcpConf.getTlsOptions(), addrStr, certificateHost, databasePort)); + + kj::String scheme = kj::str(schemePtr); + kj::String sslmode = kj::str(sslmodePtr); + return kj::refcounted(kj::mv(addr), kj::mv(tlsContext), kj::mv(tlsHostname), kj::mv(scheme), kj::mv(sslmode)); + } + + // Default service without TLS + return kj::refcounted(kj::mv(addr)); +} + // ======================================================================================= // This class exists to update the InspectorService's table of isolates when a config @@ -4653,6 +4981,9 @@ kj::Promise> Server::makeService(config::Service::Reade case config::Service::DISK: co_return makeDiskDirectoryService(name, conf.getDisk(), headerTableBuilder); + + case config::Service::DATABASE: + co_return makeDatabaseService(name, conf.getDatabase(), headerTableBuilder); } reportConfigError(kj::str("Service named \"", name, diff --git a/src/workerd/server/server.h b/src/workerd/server/server.h index 379c1b00497..a578421707d 100644 --- a/src/workerd/server/server.h +++ b/src/workerd/server/server.h @@ -220,6 +220,9 @@ class Server final: private kj::TaskSet::ErrorHandler { kj::Own makeDiskDirectoryService(kj::StringPtr name, config::DiskDirectory::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder); + kj::Own makeDatabaseService(kj::StringPtr name, + config::DatabaseServer::Reader conf, + kj::HttpHeaderTable::Builder& headerTableBuilder); kj::Promise> makeWorker(kj::StringPtr name, config::Worker::Reader conf, capnp::List::Reader extensions); @@ -252,6 +255,7 @@ class Server final: private kj::TaskSet::ErrorHandler { class ExternalTcpService; class NetworkService; class DiskDirectoryService; + class DatabaseService; class WorkerService; class WorkerEntrypointService; class HttpListener; diff --git a/src/workerd/server/workerd.capnp b/src/workerd/server/workerd.capnp index 7f2422c61ce..b923d9fc3ac 100644 --- a/src/workerd/server/workerd.capnp +++ b/src/workerd/server/workerd.capnp @@ -185,6 +185,12 @@ struct Service { # An HTTP service backed by a directory on disk, supporting a basic HTTP GET/PUT. Generally # not intended to be exposed directly to the internet; typically you want to bind this into # a Worker that adds logic for setting Content-Type and the like. + + database @6 :DatabaseServer; + # A service that fowards all requests to a specific remote database server. + # This differs from ExternalServer because this allows clients to control + # upgrading TLS for their external database. Useful for database drivers + # that call startTls after initial tcp connection has been opened. } # TODO(someday): Allow defining a list of middlewares to stack on top of the service. This would @@ -783,6 +789,35 @@ struct ExternalServer { } } +struct DatabaseServer { + # Defines a custom database egress where the client + # can control the tcp communication to the external database. + # This is useful for cases such as a database driver will + # optionally call startTls. + + address @0 :Text; + # Address/port of the server. Optional; if not specified, then you will be required to specify + # the address on the command line with with `--external-addr =`. + + tcp :group { + # Connect to the server over raw TCP. Bindings to this service will only support the + # `connect()` method; `fetch()` will throw an exception. + tlsOptions @1 :TlsOptions; + certificateHost @2 :Text; + } + + sslmode @3 :Text = "prefer"; + # sslmode will give the client the ability to upgrade TLS stream + # after initial connection is created. + # SSL modes supported are 'prefer' and 'require' + # prefer - Try TLS connection, if that fails fall back to plain tcp + # require - Client requires server to accept ssl, fails if TLS connection is not established + # Default behavior is 'prefer' + + scheme @4 :Text; + # Database scheme defines either Postgres or MySQL database. +} + struct Network { # Describes the ability to talk to a network. #