Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions pkg/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ type Configuration struct {

// DPSize is data parallel size - a number of ranks to run, minimum is 1, maximum is 8, default is 1
DPSize int `yaml:"data-parallel-size" json:"data-parallel-size"`

// SSLCertFile is the path to the SSL certificate file for HTTPS
SSLCertFile string `yaml:"ssl-certfile" json:"ssl-certfile"`
// SSLKeyFile is the path to the SSL private key file for HTTPS
SSLKeyFile string `yaml:"ssl-keyfile" json:"ssl-keyfile"`
// SelfSignedCerts enables automatic generation of self-signed certificates for HTTPS
SelfSignedCerts bool `yaml:"self-signed-certs" json:"self-signed-certs"`
}

type Metrics struct {
Expand Down Expand Up @@ -469,9 +476,23 @@ func (c *Configuration) validate() error {
if c.DPSize < 1 || c.DPSize > 8 {
return errors.New("data parallel size must be between 1 ans 8")
}

if (c.SSLCertFile == "") != (c.SSLKeyFile == "") {
return errors.New("both ssl-certfile and ssl-keyfile must be provided together")
}

if c.SelfSignedCerts && (c.SSLCertFile != "" || c.SSLKeyFile != "") {
return errors.New("cannot use both self-signed-certs and explicit ssl-certfile/ssl-keyfile")
}

return nil
}

// SSLEnabled reports whether the server should be served over HTTPS. That is
// the case when self-signed certificate generation is switched on, or when
// both a certificate file and a key file have been configured.
func (c *Configuration) SSLEnabled() bool {
	if c.SelfSignedCerts {
		return true
	}
	hasCertFile := c.SSLCertFile != ""
	hasKeyFile := c.SSLKeyFile != ""
	return hasCertFile && hasKeyFile
}

func (c *Configuration) Copy() (*Configuration, error) {
var dst Configuration
data, err := json.Marshal(c)
Expand Down Expand Up @@ -552,6 +573,10 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
f.Var(&dummyFailureTypes, "failure-types", failureTypesDescription)
f.Lookup("failure-types").NoOptDefVal = dummy

f.StringVar(&config.SSLCertFile, "ssl-certfile", config.SSLCertFile, "Path to SSL certificate file for HTTPS (optional)")
f.StringVar(&config.SSLKeyFile, "ssl-keyfile", config.SSLKeyFile, "Path to SSL private key file for HTTPS (optional)")
f.BoolVar(&config.SelfSignedCerts, "self-signed-certs", config.SelfSignedCerts, "Enable automatic generation of self-signed certificates for HTTPS")

// These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help
var dummyString string
f.StringVar(&dummyString, "config", "", "The path to a yaml configuration file. The command line values overwrite the configuration file values")
Expand Down
23 changes: 16 additions & 7 deletions pkg/llm-d-inference-sim/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *VllmSimulator) newListener() (net.Listener, error) {
return listener, nil
}

// startServer starts http server on port defined in command line
// startServer starts http/https server on port defined in command line
func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error {
r := fasthttprouter.New()

Expand All @@ -61,36 +61,45 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
r.GET("/ready", s.HandleReady)
r.POST("/tokenize", s.HandleTokenize)

server := fasthttp.Server{
server := &fasthttp.Server{
ErrorHandler: s.HandleError,
Handler: r.Handler,
Logger: s,
}

if err := s.configureSSL(server); err != nil {
return err
}

// Start server in a goroutine
serverErr := make(chan error, 1)
go func() {
s.logger.Info("HTTP server starting")
serverErr <- server.Serve(listener)
if s.config.SSLEnabled() {
s.logger.Info("Server starting", "protocol", "HTTPS", "port", s.config.Port)
serverErr <- server.ServeTLS(listener, "", "")
} else {
s.logger.Info("Server starting", "protocol", "HTTP", "port", s.config.Port)
serverErr <- server.Serve(listener)
}
}()

// Wait for either context cancellation or server error
select {
case <-ctx.Done():
s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully")
s.logger.Info("Shutdown signal received, shutting down server gracefully")

// Gracefully shutdown the server
if err := server.Shutdown(); err != nil {
s.logger.Error(err, "Error during server shutdown")
return err
}

s.logger.Info("HTTP server stopped")
s.logger.Info("Server stopped")
return nil

case err := <-serverErr:
if err != nil {
s.logger.Error(err, "HTTP server failed")
s.logger.Error(err, "Server failed")
}
return err
}
Expand Down
42 changes: 42 additions & 0 deletions pkg/llm-d-inference-sim/server_fixture_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package llmdinferencesim

import (
"os"
"path/filepath"
)

// GenerateTempCerts produces a self-signed certificate/key pair for tests and
// stores it as PEM files inside tempDir: the certificate as "cert.pem"
// (mode 0644) and the private key as "key.pem" (mode 0600).
//
// It returns the paths of the two files. If generating the pair or writing
// either file fails, both paths are empty and the underlying error is returned
// unchanged; the certificate is written first, so a failure on the key file
// leaves the certificate file behind.
func GenerateTempCerts(tempDir string) (certFile, keyFile string, err error) {
	certPEM, keyPEM, genErr := CreateSelfSignedTLSCertificatePEM()
	if genErr != nil {
		return "", "", genErr
	}

	certPath := filepath.Join(tempDir, "cert.pem")
	keyPath := filepath.Join(tempDir, "key.pem")

	// Written in order: certificate first, then the (owner-only) key.
	outputs := []struct {
		path string
		data []byte
		perm os.FileMode
	}{
		{path: certPath, data: certPEM, perm: 0644},
		{path: keyPath, data: keyPEM, perm: 0600},
	}
	for _, out := range outputs {
		if writeErr := os.WriteFile(out.path, out.data, out.perm); writeErr != nil {
			return "", "", writeErr
		}
	}

	return certPath, keyPath, nil
}
92 changes: 92 additions & 0 deletions pkg/llm-d-inference-sim/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
)

var _ = Describe("Server", func() {

It("Should respond to /health", func() {
ctx := context.TODO()
client, err := startServer(ctx, common.ModeRandom)
Expand Down Expand Up @@ -116,4 +117,95 @@ var _ = Describe("Server", func() {
Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
})
})

Context("SSL/HTTPS Configuration", func() {
	// Explicit certificate and key files on the command line enable SSL and
	// are stored verbatim in the configuration.
	It("Should parse SSL certificate configuration correctly", func() {
		tempDir := GinkgoT().TempDir()
		certFile, keyFile, err := GenerateTempCerts(tempDir)
		Expect(err).NotTo(HaveOccurred())

		// The spec feeds its arguments through os.Args, so restore the
		// original value once the spec is done.
		oldArgs := os.Args
		defer func() {
			os.Args = oldArgs
		}()

		os.Args = []string{"cmd", "--model", model, "--ssl-certfile", certFile, "--ssl-keyfile", keyFile}
		config, err := common.ParseCommandParamsAndLoadConfig()
		Expect(err).NotTo(HaveOccurred())
		Expect(config.SSLEnabled()).To(BeTrue())
		Expect(config.SSLCertFile).To(Equal(certFile))
		Expect(config.SSLKeyFile).To(Equal(keyFile))
	})

	// The --self-signed-certs flag alone is enough to enable SSL.
	It("Should parse self-signed certificate configuration correctly", func() {
		oldArgs := os.Args
		defer func() {
			os.Args = oldArgs
		}()

		os.Args = []string{"cmd", "--model", model, "--self-signed-certs"}
		config, err := common.ParseCommandParamsAndLoadConfig()
		Expect(err).NotTo(HaveOccurred())
		Expect(config.SSLEnabled()).To(BeTrue())
		Expect(config.SelfSignedCerts).To(BeTrue())
	})

	// The generated certificate holds exactly one certificate plus its key.
	It("Should create self-signed TLS certificate successfully", func() {
		cert, err := CreateSelfSignedTLSCertificate()
		Expect(err).NotTo(HaveOccurred())
		Expect(cert.Certificate).To(HaveLen(1))
		Expect(cert.PrivateKey).NotTo(BeNil())
	})

	// Supplying only one of the certificate file / key file is rejected.
	It("Should validate SSL configuration - both cert and key required", func() {
		tempDir := GinkgoT().TempDir()

		oldArgs := os.Args
		defer func() {
			os.Args = oldArgs
		}()

		certFile, _, err := GenerateTempCerts(tempDir)
		Expect(err).NotTo(HaveOccurred())

		// Certificate without key.
		os.Args = []string{"cmd", "--model", model, "--ssl-certfile", certFile}
		_, err = common.ParseCommandParamsAndLoadConfig()
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("both ssl-certfile and ssl-keyfile must be provided together"))

		_, keyFile, err := GenerateTempCerts(tempDir)
		Expect(err).NotTo(HaveOccurred())

		// Key without certificate.
		os.Args = []string{"cmd", "--model", model, "--ssl-keyfile", keyFile}
		_, err = common.ParseCommandParamsAndLoadConfig()
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("both ssl-certfile and ssl-keyfile must be provided together"))
	})

	// End-to-end: the server answers /health over HTTPS when started with
	// certificate files.
	It("Should start HTTPS server with provided SSL certificates", func(ctx SpecContext) {
		tempDir := GinkgoT().TempDir()
		certFile, keyFile, err := GenerateTempCerts(tempDir)
		Expect(err).NotTo(HaveOccurred())

		args := []string{"cmd", "--model", model, "--mode", common.ModeRandom,
			"--ssl-certfile", certFile, "--ssl-keyfile", keyFile}
		client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
		Expect(err).NotTo(HaveOccurred())

		resp, err := client.Get("https://localhost/health")
		Expect(err).NotTo(HaveOccurred())
		// Release the connection held by the response once the spec ends.
		defer func() {
			_ = resp.Body.Close()
		}()
		Expect(resp.StatusCode).To(Equal(http.StatusOK))
	})

	// End-to-end: the server answers /health over HTTPS when started with a
	// generated self-signed certificate.
	It("Should start HTTPS server with self-signed certificates", func(ctx SpecContext) {
		args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--self-signed-certs"}
		client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
		Expect(err).NotTo(HaveOccurred())

		resp, err := client.Get("https://localhost/health")
		Expect(err).NotTo(HaveOccurred())
		// Release the connection held by the response once the spec ends.
		defer func() {
			_ = resp.Body.Close()
		}()
		Expect(resp.StatusCode).To(Equal(http.StatusOK))
	})

})
})
121 changes: 121 additions & 0 deletions pkg/llm-d-inference-sim/server_tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package llmdinferencesim

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"time"

"github.com/valyala/fasthttp"
)

// configureSSL attaches a TLS configuration to server when SSL is enabled in
// the simulator configuration; when SSL is disabled it leaves server untouched
// and returns nil.
//
// The certificate is loaded from the configured certificate and key files when
// both are set; otherwise, when self-signed certificates are enabled, a new
// one is generated. A failure to obtain the certificate is logged and
// returned. The resulting configuration requires at least TLS 1.2 and lists
// the ECDHE AES-GCM and ChaCha20-Poly1305 cipher suites.
//
// Based on: https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/8d01161ec48d6b49cd371f179551b35da46e6fd6/internal/tls/tls.go
func (s *VllmSimulator) configureSSL(server *fasthttp.Server) error {
	if !s.config.SSLEnabled() {
		return nil
	}

	var (
		certificate tls.Certificate
		certErr     error
	)
	switch {
	case s.config.SSLCertFile != "" && s.config.SSLKeyFile != "":
		s.logger.Info("HTTPS server starting with certificate files", "cert", s.config.SSLCertFile, "key", s.config.SSLKeyFile)
		certificate, certErr = tls.LoadX509KeyPair(s.config.SSLCertFile, s.config.SSLKeyFile)
	case s.config.SelfSignedCerts:
		s.logger.Info("HTTPS server starting with self-signed certificate")
		certificate, certErr = CreateSelfSignedTLSCertificate()
	}
	if certErr != nil {
		s.logger.Error(certErr, "failed to create TLS certificate")
		return certErr
	}

	tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
	tlsConfig.Certificates = []tls.Certificate{certificate}
	tlsConfig.CipherSuites = []uint16{
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
	}
	server.TLSConfig = tlsConfig

	return nil
}

// CreateSelfSignedTLSCertificatePEM generates a new 4096-bit RSA key together
// with a self-signed X.509 certificate for it and returns both PEM-encoded.
//
// The certificate has a random serial number below 2^128, the organization
// "llm-d Inference Simulator", server-authentication extended key usage, and
// is valid from the time of the call for ten years (counted as 365-day years). certPEM is a
// "CERTIFICATE" block; keyPEM is a PKCS #8 "PRIVATE KEY" block.
//
// On failure both byte slices are nil. The returned error wraps the underlying
// cause, so callers can inspect it with errors.Is and errors.As.
func CreateSelfSignedTLSCertificatePEM() (certPEM, keyPEM []byte, err error) {
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating serial number: %w", err)
	}
	now := time.Now()
	notBefore := now.UTC()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"llm-d Inference Simulator"},
		},
		NotBefore:             notBefore,
		NotAfter:              now.Add(time.Hour * 24 * 365 * 10).UTC(), // 10 years
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	priv, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return nil, nil, fmt.Errorf("error generating key: %w", err)
	}

	// Self-signed: the template serves as both the certificate and its issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating certificate: %w", err)
	}

	certBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})

	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, nil, fmt.Errorf("error marshalling private key: %w", err)
	}
	keyBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})

	return certBytes, keyBytes, nil
}

// CreateSelfSignedTLSCertificate builds a self-signed certificate that the
// server can use to serve TLS. It obtains the PEM-encoded certificate and key
// from CreateSelfSignedTLSCertificatePEM and parses them into a
// tls.Certificate. If PEM generation fails, the zero tls.Certificate and that
// error are returned; otherwise the result of parsing the pair is returned.
//
// Original code: https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/8d01161ec48d6b49cd371f179551b35da46e6fd6/internal/tls/tls.go
func CreateSelfSignedTLSCertificate() (tls.Certificate, error) {
	var none tls.Certificate

	certBytes, keyBytes, genErr := CreateSelfSignedTLSCertificatePEM()
	if genErr == nil {
		return tls.X509KeyPair(certBytes, keyBytes)
	}
	return none, genErr
}
4 changes: 4 additions & 0 deletions pkg/llm-d-inference-sim/simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package llmdinferencesim

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -141,6 +142,9 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return listener.Dial()
},
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}, nil
}
Expand Down
Loading