diff --git a/examples/go/getting-started/consumer/main.go b/examples/go/getting-started/consumer/main.go index 4ab6a750b4..b758b28779 100644 --- a/examples/go/getting-started/consumer/main.go +++ b/examples/go/getting-started/consumer/main.go @@ -138,13 +138,14 @@ func getTcpOptions() (string, []tcp.Option) { } if *tlsEnabled { - tcpOptions = append(tcpOptions, tcp.WithTLS(true)) + var tlsOpts []tcp.TLSOption if *tlsCAFile != "" { - tcpOptions = append(tcpOptions, tcp.WithTLSCAFile(*tlsCAFile)) + tlsOpts = append(tlsOpts, tcp.WithTLSCAFile(*tlsCAFile)) } if *tlsDomain != "" { - tcpOptions = append(tcpOptions, tcp.WithTLSDomain(*tlsDomain)) + tlsOpts = append(tlsOpts, tcp.WithTLSDomain(*tlsDomain)) } + tcpOptions = append(tcpOptions, tcp.WithTLS(tlsOpts...)) log.Printf("TLS enabled with CA file: %s, domain: %s", *tlsCAFile, *tlsDomain) } diff --git a/examples/go/getting-started/producer/main.go b/examples/go/getting-started/producer/main.go index 5b40831d59..5aa2a2807e 100644 --- a/examples/go/getting-started/producer/main.go +++ b/examples/go/getting-started/producer/main.go @@ -142,13 +142,14 @@ func getTcpOptions() (string, []tcp.Option) { } if *tlsEnabled { - tcpOptions = append(tcpOptions, tcp.WithTLS(true)) + var tlsOpts []tcp.TLSOption if *tlsCAFile != "" { - tcpOptions = append(tcpOptions, tcp.WithTLSCAFile(*tlsCAFile)) + tlsOpts = append(tlsOpts, tcp.WithTLSCAFile(*tlsCAFile)) } if *tlsDomain != "" { - tcpOptions = append(tcpOptions, tcp.WithTLSDomain(*tlsDomain)) + tlsOpts = append(tlsOpts, tcp.WithTLSDomain(*tlsDomain)) } + tcpOptions = append(tcpOptions, tcp.WithTLS(tlsOpts...)) log.Printf("TLS enabled with CA file: %s, domain: %s", *tlsCAFile, *tlsDomain) } diff --git a/foreign/go/client/tcp/tcp_core.go b/foreign/go/client/tcp/tcp_core.go index 796a524ba0..91f73f2477 100644 --- a/foreign/go/client/tcp/tcp_core.go +++ b/foreign/go/client/tcp/tcp_core.go @@ -63,13 +63,7 @@ type config struct { serverAddress string // tlsEnabled indicates whether to use TLS when connecting to the server tlsEnabled bool - // tlsDomain is the domain to use for TLS when connecting to the server - // If empty, automatically extracts the hostname/IP from serverAddress - tlsDomain string - // tlsCAFile is the path to the CA file to use for TLS - tlsCAFile string - // tlsValidateCertificate indicates whether to validate the server's TLS certificate - tlsValidateCertificate bool + tls tlsConfig // autoLogin indicates whether to automatically login user after establishing connection. autoLogin AutoLogin // reconnection indicates whether to automatically reconnect when disconnected @@ -80,14 +74,12 @@ type config struct { func defaultTcpClientConfig() config { return config{ - serverAddress: "127.0.0.1:8090", - tlsEnabled: false, - tlsDomain: "", - tlsCAFile: "", - tlsValidateCertificate: true, - autoLogin: AutoLogin{}, - reconnection: defaultTcpClientReconnectionConfig(), - noDelay: false, + serverAddress: "127.0.0.1:8090", + tlsEnabled: false, + tls: defaultTLSConfig(), + autoLogin: AutoLogin{}, + reconnection: defaultTcpClientReconnectionConfig(), + noDelay: false, } } @@ -107,6 +99,24 @@ func defaultTcpClientReconnectionConfig() tcpClientReconnectionConfig { } } +type tlsConfig struct { + // tlsDomain is the domain to use for TLS when connecting to the server + // If empty, automatically extracts the hostname/IP from serverAddress + tlsDomain string + // tlsCAFile is the path to the CA file to use for TLS + tlsCAFile string + // tlsValidateCertificate indicates whether to validate the server's TLS certificate + tlsValidateCertificate bool +} + +func defaultTLSConfig() tlsConfig { + return tlsConfig{ + tlsDomain: "", + tlsCAFile: "", + tlsValidateCertificate: true, + } +} + type AutoLogin struct { enabled bool credentials Credentials @@ -145,32 +155,40 @@ func WithServerAddress(address string) Option { } } -// WithTLS enables or disables TLS for the TCP client. -func WithTLS(enabled bool) Option { +// TLSOption is a functional option for configuring TLS settings. +type TLSOption func(cfg *tlsConfig) + +// WithTLS enables TLS for the TCP client and applies the given TLS options. +func WithTLS(tlsOpts ...TLSOption) Option { return func(opts *Options) { - opts.config.tlsEnabled = enabled + opts.config.tlsEnabled = true + for _, tlsOpt := range tlsOpts { + if tlsOpt != nil { + tlsOpt(&opts.config.tls) + } + } } } // WithTLSDomain sets the TLS domain for server name indication (SNI). -// If empty, the domain will be automatically extracted from the server address. -func WithTLSDomain(domain string) Option { - return func(opts *Options) { - opts.config.tlsDomain = domain +// If not provided, the domain will be automatically extracted from the server address. +func WithTLSDomain(domain string) TLSOption { + return func(cfg *tlsConfig) { + cfg.tlsDomain = domain } } // WithTLSCAFile sets the path to the CA certificate file for TLS verification. -func WithTLSCAFile(path string) Option { - return func(opts *Options) { - opts.config.tlsCAFile = path +func WithTLSCAFile(path string) TLSOption { + return func(cfg *tlsConfig) { + cfg.tlsCAFile = path } } // WithTLSValidateCertificate enables or disables TLS certificate validation. -func WithTLSValidateCertificate(validate bool) Option { - return func(opts *Options) { - opts.config.tlsValidateCertificate = validate +func WithTLSValidateCertificate(validate bool) TLSOption { + return func(cfg *tlsConfig) { + cfg.tlsValidateCertificate = validate } } @@ -334,7 +352,7 @@ func (c *IggyTcpClient) connect() error { attempts = uint(c.config.reconnection.maxRetries) interval = c.config.reconnection.interval } - // TODO handle tls logic + var conn net.Conn if err := retry.New( retry.Attempts(attempts), @@ -392,11 +410,11 @@ func (c *IggyTcpClient) connect() error { func (c *IggyTcpClient) createTLSConfig() (*tls.Config, error) { tlsConfig := &tls.Config{ - InsecureSkipVerify: !c.config.tlsValidateCertificate, + InsecureSkipVerify: !c.config.tls.tlsValidateCertificate, } // Set server name for SNI - serverName := c.config.tlsDomain + serverName := c.config.tls.tlsDomain if serverName == "" { // Extract hostname from server address (format: "host:port") host := c.currentServerAddress @@ -412,8 +430,8 @@ func (c *IggyTcpClient) createTLSConfig() (*tls.Config, error) { tlsConfig.ServerName = serverName // Load CA certificate if provided - if c.config.tlsCAFile != "" { - caCert, err := os.ReadFile(c.config.tlsCAFile) + if c.config.tls.tlsCAFile != "" { + caCert, err := os.ReadFile(c.config.tls.tlsCAFile) if err != nil { return nil, ierror.ErrInvalidTlsCertificatePath } diff --git a/foreign/go/tests/tls_test.go b/foreign/go/tests/tls_test.go index 701d2e1549..c3525c9ddf 100644 --- a/foreign/go/tests/tls_test.go +++ b/foreign/go/tests/tls_test.go @@ -155,9 +155,10 @@ func TestTCPTLSConnection_WithCA_Success(t *testing.T) { cli, err := client.NewIggyClient( client.WithTcp( tcp.WithServerAddress(connectAddr), - tcp.WithTLS(true), - tcp.WithTLSCAFile(caFile), - tcp.WithTLSDomain("localhost"), + tcp.WithTLS( + tcp.WithTLSCAFile(caFile), + tcp.WithTLSDomain("localhost"), + ), ), ) require.NoError(t, err, "Failed to create TLS client") @@ -183,7 +184,6 @@ func TestTCPTLSConnection_WithoutTLS_Failure(t *testing.T) { cli, err := client.NewIggyClient( client.WithTcp( tcp.WithServerAddress(connectAddr), - tcp.WithTLS(false), ), ) @@ -211,9 +211,10 @@ func TestTCPTLSConnection_MessageFlow_Success(t *testing.T) { cli, err := client.NewIggyClient( client.WithTcp( tcp.WithServerAddress(connectAddr), - tcp.WithTLS(true), - tcp.WithTLSCAFile(caFile), - tcp.WithTLSDomain("localhost"), + tcp.WithTLS( + tcp.WithTLSCAFile(caFile), + tcp.WithTLSDomain("localhost"), + ), ), ) require.NoError(t, err, "Failed to create TLS client")