Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/go/getting-started/consumer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,14 @@ func getTcpOptions() (string, []tcp.Option) {
}

if *tlsEnabled {
tcpOptions = append(tcpOptions, tcp.WithTLS(true))
var tlsOpts []tcp.TLSOption
if *tlsCAFile != "" {
tcpOptions = append(tcpOptions, tcp.WithTLSCAFile(*tlsCAFile))
tlsOpts = append(tlsOpts, tcp.WithTLSCAFile(*tlsCAFile))
}
if *tlsDomain != "" {
tcpOptions = append(tcpOptions, tcp.WithTLSDomain(*tlsDomain))
tlsOpts = append(tlsOpts, tcp.WithTLSDomain(*tlsDomain))
}
tcpOptions = append(tcpOptions, tcp.WithTLS(tlsOpts...))
log.Printf("TLS enabled with CA file: %s, domain: %s", *tlsCAFile, *tlsDomain)
}

Expand Down
7 changes: 4 additions & 3 deletions examples/go/getting-started/producer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,14 @@ func getTcpOptions() (string, []tcp.Option) {
}

if *tlsEnabled {
tcpOptions = append(tcpOptions, tcp.WithTLS(true))
var tlsOpts []tcp.TLSOption
if *tlsCAFile != "" {
tcpOptions = append(tcpOptions, tcp.WithTLSCAFile(*tlsCAFile))
tlsOpts = append(tlsOpts, tcp.WithTLSCAFile(*tlsCAFile))
}
if *tlsDomain != "" {
tcpOptions = append(tcpOptions, tcp.WithTLSDomain(*tlsDomain))
tlsOpts = append(tlsOpts, tcp.WithTLSDomain(*tlsDomain))
}
tcpOptions = append(tcpOptions, tcp.WithTLS(tlsOpts...))
log.Printf("TLS enabled with CA file: %s, domain: %s", *tlsCAFile, *tlsDomain)
}

Expand Down
84 changes: 51 additions & 33 deletions foreign/go/client/tcp/tcp_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,7 @@ type config struct {
serverAddress string
// tlsEnabled indicates whether to use TLS when connecting to the server
tlsEnabled bool
// tlsDomain is the domain to use for TLS when connecting to the server
// If empty, automatically extracts the hostname/IP from serverAddress
tlsDomain string
// tlsCAFile is the path to the CA file to use for TLS
tlsCAFile string
// tlsValidateCertificate indicates whether to validate the server's TLS certificate
tlsValidateCertificate bool
tls tlsConfig
// autoLogin indicates whether to automatically login user after establishing connection.
autoLogin AutoLogin
// reconnection indicates whether to automatically reconnect when disconnected
Expand All @@ -80,14 +74,12 @@ type config struct {

func defaultTcpClientConfig() config {
return config{
serverAddress: "127.0.0.1:8090",
tlsEnabled: false,
tlsDomain: "",
tlsCAFile: "",
tlsValidateCertificate: true,
autoLogin: AutoLogin{},
reconnection: defaultTcpClientReconnectionConfig(),
noDelay: false,
serverAddress: "127.0.0.1:8090",
tlsEnabled: false,
tls: defaultTLSConfig(),
autoLogin: AutoLogin{},
reconnection: defaultTcpClientReconnectionConfig(),
noDelay: false,
}
}

Expand All @@ -107,6 +99,24 @@ func defaultTcpClientReconnectionConfig() tcpClientReconnectionConfig {
}
}

type tlsConfig struct {
// tlsDomain is the domain to use for TLS when connecting to the server
// If empty, automatically extracts the hostname/IP from serverAddress
tlsDomain string
// tlsCAFile is the path to the CA file to use for TLS
tlsCAFile string
// tlsValidateCertificate indicates whether to validate the server's TLS certificate
tlsValidateCertificate bool
}

func defaultTLSConfig() tlsConfig {
return tlsConfig{
tlsDomain: "",
tlsCAFile: "",
tlsValidateCertificate: true,
}
}

type AutoLogin struct {
enabled bool
credentials Credentials
Expand Down Expand Up @@ -145,32 +155,40 @@ func WithServerAddress(address string) Option {
}
}

// WithTLS enables or disables TLS for the TCP client.
func WithTLS(enabled bool) Option {
// TLSOption is a functional option for configuring TLS settings.
type TLSOption func(cfg *tlsConfig)

// WithTLS enables TLS for the TCP client and applies the given TLS options.
func WithTLS(tlsOpts ...TLSOption) Option {
return func(opts *Options) {
opts.config.tlsEnabled = enabled
opts.config.tlsEnabled = true
for _, tlsOpt := range tlsOpts {
if tlsOpt != nil {
tlsOpt(&opts.config.tls)
}
}
}
}

// WithTLSDomain sets the TLS domain for server name indication (SNI).
// If empty, the domain will be automatically extracted from the server address.
func WithTLSDomain(domain string) Option {
return func(opts *Options) {
opts.config.tlsDomain = domain
// If not provided, the domain will be automatically extracted from the server address.
func WithTLSDomain(domain string) TLSOption {
return func(cfg *tlsConfig) {
cfg.tlsDomain = domain
}
}

// WithTLSCAFile sets the path to the CA certificate file for TLS verification.
func WithTLSCAFile(path string) Option {
return func(opts *Options) {
opts.config.tlsCAFile = path
func WithTLSCAFile(path string) TLSOption {
return func(cfg *tlsConfig) {
cfg.tlsCAFile = path
}
}

// WithTLSValidateCertificate enables or disables TLS certificate validation.
func WithTLSValidateCertificate(validate bool) Option {
return func(opts *Options) {
opts.config.tlsValidateCertificate = validate
func WithTLSValidateCertificate(validate bool) TLSOption {
return func(cfg *tlsConfig) {
cfg.tlsValidateCertificate = validate
}
}

Expand Down Expand Up @@ -334,7 +352,7 @@ func (c *IggyTcpClient) connect() error {
attempts = uint(c.config.reconnection.maxRetries)
interval = c.config.reconnection.interval
}
// TODO handle tls logic

var conn net.Conn
if err := retry.New(
retry.Attempts(attempts),
Expand Down Expand Up @@ -392,11 +410,11 @@ func (c *IggyTcpClient) connect() error {

func (c *IggyTcpClient) createTLSConfig() (*tls.Config, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: !c.config.tlsValidateCertificate,
InsecureSkipVerify: !c.config.tls.tlsValidateCertificate,
}

// Set server name for SNI
serverName := c.config.tlsDomain
serverName := c.config.tls.tlsDomain
if serverName == "" {
// Extract hostname from server address (format: "host:port")
host := c.currentServerAddress
Expand All @@ -412,8 +430,8 @@ func (c *IggyTcpClient) createTLSConfig() (*tls.Config, error) {
tlsConfig.ServerName = serverName

// Load CA certificate if provided
if c.config.tlsCAFile != "" {
caCert, err := os.ReadFile(c.config.tlsCAFile)
if c.config.tls.tlsCAFile != "" {
caCert, err := os.ReadFile(c.config.tls.tlsCAFile)
if err != nil {
return nil, ierror.ErrInvalidTlsCertificatePath
}
Expand Down
15 changes: 8 additions & 7 deletions foreign/go/tests/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ func TestTCPTLSConnection_WithCA_Success(t *testing.T) {
cli, err := client.NewIggyClient(
client.WithTcp(
tcp.WithServerAddress(connectAddr),
tcp.WithTLS(true),
tcp.WithTLSCAFile(caFile),
tcp.WithTLSDomain("localhost"),
tcp.WithTLS(
tcp.WithTLSCAFile(caFile),
tcp.WithTLSDomain("localhost"),
),
),
)
require.NoError(t, err, "Failed to create TLS client")
Expand All @@ -183,7 +184,6 @@ func TestTCPTLSConnection_WithoutTLS_Failure(t *testing.T) {
cli, err := client.NewIggyClient(
client.WithTcp(
tcp.WithServerAddress(connectAddr),
tcp.WithTLS(false),
),
)

Expand Down Expand Up @@ -211,9 +211,10 @@ func TestTCPTLSConnection_MessageFlow_Success(t *testing.T) {
cli, err := client.NewIggyClient(
client.WithTcp(
tcp.WithServerAddress(connectAddr),
tcp.WithTLS(true),
tcp.WithTLSCAFile(caFile),
tcp.WithTLSDomain("localhost"),
tcp.WithTLS(
tcp.WithTLSCAFile(caFile),
tcp.WithTLSDomain("localhost"),
),
),
)
require.NoError(t, err, "Failed to create TLS client")
Expand Down
Loading