From b9d87b588ab5c78607081e9d081fa02dc9208af3 Mon Sep 17 00:00:00 2001 From: Bartosz Majsak Date: Thu, 18 Sep 2025 08:54:57 +0200 Subject: [PATCH] feat(server): enables TLS mode This change is made to support environments that require HTTPS. The server can now run over HTTPS with either user-provided certs or self-signed ones. This is done through flags aligned with vLLM (`--ssl-certfile`, `--ssl-keyfile`). Additionally `--self-signed-certs` has been provided for self-signed certs. Signed-off-by: Bartosz Majsak --- pkg/common/config.go | 25 ++++ pkg/llm-d-inference-sim/server.go | 23 +++- .../server_fixture_test.go | 42 ++++++ pkg/llm-d-inference-sim/server_test.go | 92 +++++++++++++ pkg/llm-d-inference-sim/server_tls.go | 121 ++++++++++++++++++ pkg/llm-d-inference-sim/simulator_test.go | 4 + 6 files changed, 300 insertions(+), 7 deletions(-) create mode 100644 pkg/llm-d-inference-sim/server_fixture_test.go create mode 100644 pkg/llm-d-inference-sim/server_tls.go diff --git a/pkg/common/config.go b/pkg/common/config.go index c367c029..ca8e5aa5 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -174,6 +174,13 @@ type Configuration struct { // DPSize is data parallel size - a number of ranks to run, minimum is 1, maximum is 8, default is 1 DPSize int `yaml:"data-parallel-size" json:"data-parallel-size"` + + // SSLCertFile is the path to the SSL certificate file for HTTPS + SSLCertFile string `yaml:"ssl-certfile" json:"ssl-certfile"` + // SSLKeyFile is the path to the SSL private key file for HTTPS + SSLKeyFile string `yaml:"ssl-keyfile" json:"ssl-keyfile"` + // SelfSignedCerts enables automatic generation of self-signed certificates for HTTPS + SelfSignedCerts bool `yaml:"self-signed-certs" json:"self-signed-certs"` } type Metrics struct { @@ -469,9 +476,23 @@ func (c *Configuration) validate() error { if c.DPSize < 1 || c.DPSize > 8 { return errors.New("data parallel size must be between 1 ans 8") } + + if (c.SSLCertFile == "") != (c.SSLKeyFile == "") { + return errors.New("both ssl-certfile and ssl-keyfile must be provided together") + } + + if c.SelfSignedCerts && (c.SSLCertFile != "" || c.SSLKeyFile != "") { + return errors.New("cannot use both self-signed-certs and explicit ssl-certfile/ssl-keyfile") + } + return nil } +// SSLEnabled returns true if SSL is enabled either via certificate files or self-signed certificates +func (c *Configuration) SSLEnabled() bool { + return (c.SSLCertFile != "" && c.SSLKeyFile != "") || c.SelfSignedCerts +} + func (c *Configuration) Copy() (*Configuration, error) { var dst Configuration data, err := json.Marshal(c) @@ -552,6 +573,10 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.Var(&dummyFailureTypes, "failure-types", failureTypesDescription) f.Lookup("failure-types").NoOptDefVal = dummy + f.StringVar(&config.SSLCertFile, "ssl-certfile", config.SSLCertFile, "Path to SSL certificate file for HTTPS (optional)") + f.StringVar(&config.SSLKeyFile, "ssl-keyfile", config.SSLKeyFile, "Path to SSL private key file for HTTPS (optional)") + f.BoolVar(&config.SelfSignedCerts, "self-signed-certs", config.SelfSignedCerts, "Enable automatic generation of self-signed certificates for HTTPS") + // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string f.StringVar(&dummyString, "config", "", "The path to a yaml configuration file. The command line values overwrite the configuration file values") diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go index 5fb77d5e..1c6284a1 100644 --- a/pkg/llm-d-inference-sim/server.go +++ b/pkg/llm-d-inference-sim/server.go @@ -42,7 +42,7 @@ func (s *VllmSimulator) newListener() (net.Listener, error) { return listener, nil } -// startServer starts http server on port defined in command line +// startServer starts http/https server on port defined in command line func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { r := fasthttprouter.New() @@ -61,23 +61,32 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) r.GET("/ready", s.HandleReady) r.POST("/tokenize", s.HandleTokenize) - server := fasthttp.Server{ + server := &fasthttp.Server{ ErrorHandler: s.HandleError, Handler: r.Handler, Logger: s, } + if err := s.configureSSL(server); err != nil { + return err + } + // Start server in a goroutine serverErr := make(chan error, 1) go func() { - s.logger.Info("HTTP server starting") - serverErr <- server.Serve(listener) + if s.config.SSLEnabled() { + s.logger.Info("Server starting", "protocol", "HTTPS", "port", s.config.Port) + serverErr <- server.ServeTLS(listener, "", "") + } else { + s.logger.Info("Server starting", "protocol", "HTTP", "port", s.config.Port) + serverErr <- server.Serve(listener) + } }() // Wait for either context cancellation or server error select { case <-ctx.Done(): - s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") + s.logger.Info("Shutdown signal received, shutting down server gracefully") // Gracefully shutdown the server if err := server.Shutdown(); err != nil { @@ -85,12 +94,12 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) return err } - s.logger.Info("HTTP server stopped") + s.logger.Info("Server stopped") return nil case err := <-serverErr: if err != nil { - s.logger.Error(err, "HTTP server failed") + s.logger.Error(err, "Server failed") } return err } diff --git a/pkg/llm-d-inference-sim/server_fixture_test.go b/pkg/llm-d-inference-sim/server_fixture_test.go new file mode 100644 index 00000000..96762a9b --- /dev/null +++ b/pkg/llm-d-inference-sim/server_fixture_test.go @@ -0,0 +1,42 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "os" + "path/filepath" +) + +// GenerateTempCerts creates temporary SSL certificate and key files for testing +func GenerateTempCerts(tempDir string) (certFile, keyFile string, err error) { + certPEM, keyPEM, err := CreateSelfSignedTLSCertificatePEM() + if err != nil { + return "", "", err + } + + certFile = filepath.Join(tempDir, "cert.pem") + if err := os.WriteFile(certFile, certPEM, 0644); err != nil { + return "", "", err + } + + keyFile = filepath.Join(tempDir, "key.pem") + if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { + return "", "", err + } + + return certFile, keyFile, nil +} diff --git a/pkg/llm-d-inference-sim/server_test.go b/pkg/llm-d-inference-sim/server_test.go index ee1dfd7c..1f610562 100644 --- a/pkg/llm-d-inference-sim/server_test.go +++ b/pkg/llm-d-inference-sim/server_test.go @@ -31,6 +31,7 @@ import ( ) var _ = Describe("Server", func() { + It("Should respond to /health", func() { ctx := context.TODO() client, err := startServer(ctx, common.ModeRandom) @@ -116,4 +117,95 @@ var _ = Describe("Server", func() { Expect(tokenizeResp.MaxModelLen).To(Equal(2048)) }) }) + + Context("SSL/HTTPS Configuration", func() { + It("Should parse SSL certificate configuration correctly", func() { + tempDir := GinkgoT().TempDir() + certFile, keyFile, err := GenerateTempCerts(tempDir) + Expect(err).NotTo(HaveOccurred()) + + oldArgs := os.Args + defer func() { + os.Args = oldArgs + }() + + os.Args = []string{"cmd", "--model", model, "--ssl-certfile", certFile, "--ssl-keyfile", keyFile} + config, err := common.ParseCommandParamsAndLoadConfig() + Expect(err).NotTo(HaveOccurred()) + Expect(config.SSLEnabled()).To(BeTrue()) + Expect(config.SSLCertFile).To(Equal(certFile)) + Expect(config.SSLKeyFile).To(Equal(keyFile)) + }) + + It("Should parse self-signed certificate configuration correctly", func() { + oldArgs := os.Args + defer func() { + os.Args = oldArgs + }() + + os.Args = []string{"cmd", "--model", model, "--self-signed-certs"} + config, err := common.ParseCommandParamsAndLoadConfig() + Expect(err).NotTo(HaveOccurred()) + Expect(config.SSLEnabled()).To(BeTrue()) + Expect(config.SelfSignedCerts).To(BeTrue()) + }) + + It("Should create self-signed TLS certificate successfully", func() { + cert, err := CreateSelfSignedTLSCertificate() + Expect(err).NotTo(HaveOccurred()) + Expect(cert.Certificate).To(HaveLen(1)) + Expect(cert.PrivateKey).NotTo(BeNil()) + }) + + It("Should validate SSL configuration - both cert and key required", func() { + tempDir := GinkgoT().TempDir() + + oldArgs := os.Args + defer func() { + os.Args = oldArgs + }() + + certFile, _, err := GenerateTempCerts(tempDir) + Expect(err).NotTo(HaveOccurred()) + + os.Args = []string{"cmd", "--model", model, "--ssl-certfile", certFile} + _, err = common.ParseCommandParamsAndLoadConfig() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("both ssl-certfile and ssl-keyfile must be provided together")) + + _, keyFile, err := GenerateTempCerts(tempDir) + Expect(err).NotTo(HaveOccurred()) + + os.Args = []string{"cmd", "--model", model, "--ssl-keyfile", keyFile} + _, err = common.ParseCommandParamsAndLoadConfig() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("both ssl-certfile and ssl-keyfile must be provided together")) + }) + + It("Should start HTTPS server with provided SSL certificates", func(ctx SpecContext) { + tempDir := GinkgoT().TempDir() + certFile, keyFile, err := GenerateTempCerts(tempDir) + Expect(err).NotTo(HaveOccurred()) + + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--ssl-certfile", certFile, "--ssl-keyfile", keyFile} + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Get("https://localhost/health") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + It("Should start HTTPS server with self-signed certificates", func(ctx SpecContext) { + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--self-signed-certs"} + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Get("https://localhost/health") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + }) }) diff --git a/pkg/llm-d-inference-sim/server_tls.go b/pkg/llm-d-inference-sim/server_tls.go new file mode 100644 index 00000000..601418d7 --- /dev/null +++ b/pkg/llm-d-inference-sim/server_tls.go @@ -0,0 +1,121 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "time" + + "github.com/valyala/fasthttp" +) + +// Based on: https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/8d01161ec48d6b49cd371f179551b35da46e6fd6/internal/tls/tls.go +func (s *VllmSimulator) configureSSL(server *fasthttp.Server) error { + if !s.config.SSLEnabled() { + return nil + } + + var cert tls.Certificate + var err error + + if s.config.SSLCertFile != "" && s.config.SSLKeyFile != "" { + s.logger.Info("HTTPS server starting with certificate files", "cert", s.config.SSLCertFile, "key", s.config.SSLKeyFile) + cert, err = tls.LoadX509KeyPair(s.config.SSLCertFile, s.config.SSLKeyFile) + } else if s.config.SelfSignedCerts { + s.logger.Info("HTTPS server starting with self-signed certificate") + cert, err = CreateSelfSignedTLSCertificate() + } + + if err != nil { + s.logger.Error(err, "failed to create TLS certificate") + return err + } + + server.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + }, + } + + return nil +} + +// CreateSelfSignedTLSCertificatePEM creates a self-signed cert and returns the PEM-encoded certificate and key bytes +func CreateSelfSignedTLSCertificatePEM() (certPEM, keyPEM []byte, err error) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, fmt.Errorf("error creating serial number: %v", err) + } + now := time.Now() + notBefore := now.UTC() + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"llm-d Inference Simulator"}, + }, + NotBefore: notBefore, + NotAfter: now.Add(time.Hour * 24 * 365 * 10).UTC(), // 10 years + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + priv, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, fmt.Errorf("error generating key: %v", err) + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("error creating certificate: %v", err) + } + + certBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, nil, fmt.Errorf("error marshalling private key: %v", err) + } + keyBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + + return certBytes, keyBytes, nil +} + +// CreateSelfSignedTLSCertificate creates a self-signed cert the server can use to serve TLS. +// Original code: https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/8d01161ec48d6b49cd371f179551b35da46e6fd6/internal/tls/tls.go +func CreateSelfSignedTLSCertificate() (tls.Certificate, error) { + certPEM, keyPEM, err := CreateSelfSignedTLSCertificatePEM() + if err != nil { + return tls.Certificate{}, err + } + return tls.X509KeyPair(certPEM, keyPEM) +} diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 59f92175..e504c5d5 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -18,6 +18,7 @@ package llmdinferencesim import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -141,6 +142,9 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return listener.Dial() }, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, }, }, nil }