Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion net/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func TestConnWriteWithContext(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tcpConn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
require.NoError(t, err)
c := NewConn(tcpConn)
defer func() {
Expand Down
19 changes: 12 additions & 7 deletions net/tlslistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ func TestTLSListenerAcceptWithContext(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)

c, err := tls.DialWithDialer(&net.Dialer{
Timeout: time.Millisecond * 400,
}, "tcp", listener.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
})
d := &tls.Dialer{
NetDialer: &net.Dialer{
Timeout: time.Millisecond * 400,
},
Config: &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
},
}
c, err := d.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
continue
}
Expand Down Expand Up @@ -186,7 +190,8 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)
func() {
conn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
return
}
Expand Down
16 changes: 15 additions & 1 deletion options/tcpOptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package options

import (
"crypto/tls"
"time"

tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client"
tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server"
Expand Down Expand Up @@ -34,11 +35,24 @@ func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.DisableTCPSignalMessageCSM = true
}

// WithDisableTCPSignalMessageCSM don't send CSM when client conn is created.
func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt {
return DisableTCPSignalMessageCSMOpt{}
}

type CSMExchangeTimeoutOpt struct {
timeout time.Duration
}

func (o CSMExchangeTimeoutOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.CSMExchangeTimeout = o.timeout
}

func WithCSMExchangeTimeout(timeout time.Duration) CSMExchangeTimeoutOpt {
return CSMExchangeTimeoutOpt{
timeout: timeout,
}
}

// TLSOpt tls configuration option.
type TLSOpt struct {
tlsCfg *tls.Config
Expand Down
39 changes: 35 additions & 4 deletions tcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/blockwise"
Expand All @@ -30,19 +31,23 @@ func Dial(target string, opts ...Option) (*client.Conn, error) {
var conn net.Conn
var err error
if cfg.TLSCfg != nil {
conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg)
d := &tls.Dialer{
NetDialer: cfg.Dialer,
Config: cfg.TLSCfg,
}
conn, err = d.DialContext(cfg.Ctx, cfg.Net, target)
} else {
conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target)
}
if err != nil {
return nil, err
}
opts = append(opts, options.WithCloseSocket())
return Client(conn, opts...), nil
return Client(conn, opts...)
}

// Client creates client over tcp/tcp-tls connection.
func Client(conn net.Conn, opts ...Option) *client.Conn {
func Client(conn net.Conn, opts ...Option) (*client.Conn, error) {
cfg := client.DefaultConfig
for _, o := range opts {
o.TCPClientApply(&cfg)
Expand Down Expand Up @@ -100,12 +105,38 @@ func Client(conn net.Conn, opts ...Option) *client.Conn {
return cc.Context().Err() == nil
})

var csmExchangeDone chan struct{}
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})

cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
close(csmExchangeDone)
}
})
}

go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()

return cc
// if CSM messages are enabled, wait for the CSM messages to be exchanged
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
select {
case <-time.After(cfg.CSMExchangeTimeout):
err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
cfg.Errors(err)
cc.Close() // Close connection on timeout
return nil, err // or return cc with an error state
case <-csmExchangeDone:
// CSM exchange completed successfully
}
// Clear the handler after exchange is complete or timed out
cc.SetTCPSignalReceivedHandler(nil)
}

return cc, nil
}
1 change: 1 addition & 0 deletions tcp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ type Config struct {
DisablePeerTCPSignalMessageCSMs bool
CloseSocket bool
DisableTCPSignalMessageCSM bool
CSMExchangeTimeout time.Duration
}
40 changes: 35 additions & 5 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/plgd-dev/go-coap/v3/message"
Expand All @@ -27,6 +28,8 @@ type InactivityMonitor interface {
CheckInactivity(now time.Time, cc *Conn)
}

type TCPSignalReceivedHandler func(codes.Code)

type (
HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message)
ErrorFunc = func(error)
Expand All @@ -51,6 +54,8 @@ type Conn struct {
blockwiseSZX blockwise.SZX
peerMaxMessageSize atomic.Uint32
disablePeerTCPSignalMessageCSMs bool
tcpSignalReceivedHandler TCPSignalReceivedHandler
handlerMutex sync.RWMutex
peerBlockWiseTranferEnabled atomic.Bool

receivedMessageReader *client.ReceivedMessageReader[*Conn]
Expand Down Expand Up @@ -267,6 +272,12 @@ func (cc *Conn) Run() (err error) {
return cc.session.Run(cc)
}

func (cc *Conn) SetTCPSignalReceivedHandler(handler TCPSignalReceivedHandler) {
cc.handlerMutex.Lock()
defer cc.handlerMutex.Unlock()
cc.tcpSignalReceivedHandler = handler
}

// AddOnClose calls function on close connection event.
func (cc *Conn) AddOnClose(f EventFunc) {
cc.session.AddOnClose(f)
Expand Down Expand Up @@ -370,6 +381,14 @@ func (cc *Conn) sendPong(token message.Token) error {
return cc.Session().WriteMessage(req)
}

func (cc *Conn) handleTCPSignalReceived(code codes.Code) {
cc.handlerMutex.RLock()
defer cc.handlerMutex.RUnlock()
if cc.tcpSignalReceivedHandler != nil {
cc.tcpSignalReceivedHandler(code)
}
}

func (cc *Conn) handleSignals(r *pool.Message) bool {
switch r.Code() {
case codes.CSM:
Expand All @@ -382,6 +401,9 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if r.HasOption(message.TCPBlockWiseTransfer) {
cc.peerBlockWiseTranferEnabled.Store(true)
}

// signal CSM message is received.
cc.handleTCPSignalReceived(codes.CSM)
return true
case codes.Ping:
// if r.HasOption(message.TCPCustody) {
Expand All @@ -390,21 +412,29 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) {
cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err))
}

cc.handleTCPSignalReceived(codes.Ping)
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Pong)
return true
case codes.Release:
// if r.HasOption(message.TCPAlternativeAddress) {
// TODO
// }

cc.handleTCPSignalReceived(codes.Release)
return true
case codes.Abort:
// if r.HasOption(message.TCPBadCSMOption) {
// TODO
// }
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Abort)
return true
}
return false
Expand Down
75 changes: 75 additions & 0 deletions tcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/plgd-dev/go-coap/v3/options/config"
"github.com/plgd-dev/go-coap/v3/pkg/runner/periodic"
"github.com/plgd-dev/go-coap/v3/tcp/client"
"github.com/plgd-dev/go-coap/v3/tcp/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -839,3 +840,77 @@ func TestConnRequestMonitorDropRequest(t *testing.T) {
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestConnWithCSMExchangeTimeout(t *testing.T) {
type args struct {
clientOptions []Option
serverOptions []server.Option
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "client-server-no-csm",
args: args{
clientOptions: []Option{},
serverOptions: []server.Option{},
},
wantErr: false,
},
{
name: "client-server-csm-success",
args: args{
clientOptions: []Option{
options.WithCSMExchangeTimeout(time.Second * 3),
},
},
wantErr: false,
},
{
name: "client-server-csm-timeout",
args: args{
clientOptions: []Option{
options.WithCSMExchangeTimeout(time.Second * 3),
},
serverOptions: []server.Option{
options.WithDisableTCPSignalMessageCSM(),
},
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, err := coapNet.NewTCPListener("tcp", "")
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

s := NewServer(tt.args.serverOptions...)
defer s.Stop()
wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
assert.NoError(t, errS)
}()

client, err := Dial(l.Addr().String(),
tt.args.clientOptions...)
if tt.wantErr {
require.Nil(t, client)
require.Error(t, err)
} else {
require.NotNil(t, client)
require.NoError(t, err)
}
})
}
}
3 changes: 2 additions & 1 deletion tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ func TestServerKeepAliveMonitor(t *testing.T) {
assert.NoError(t, errS)
}()

cc, err := net.Dial("tcp", ld.Addr().String())
dialer := net.Dialer{}
cc, err := dialer.DialContext(context.Background(), "tcp", ld.Addr().String())
require.NoError(t, err)
defer func() {
_ = cc.Close()
Expand Down
Loading