Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion net/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func TestConnWriteWithContext(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tcpConn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
require.NoError(t, err)
c := NewConn(tcpConn)
defer func() {
Expand Down
19 changes: 12 additions & 7 deletions net/tlslistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ func TestTLSListenerAcceptWithContext(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)

c, err := tls.DialWithDialer(&net.Dialer{
Timeout: time.Millisecond * 400,
}, "tcp", listener.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
})
d := &tls.Dialer{
NetDialer: &net.Dialer{
Timeout: time.Millisecond * 400,
},
Config: &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
},
}
c, err := d.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
continue
}
Expand Down Expand Up @@ -186,7 +190,8 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)
func() {
conn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
return
}
Expand Down
16 changes: 15 additions & 1 deletion options/tcpOptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package options

import (
"crypto/tls"
"time"

tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client"
tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server"
Expand Down Expand Up @@ -34,11 +35,24 @@ func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.DisableTCPSignalMessageCSM = true
}

// WithDisableTCPSignalMessageCSM don't send CSM when client conn is created.
func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt {
return DisableTCPSignalMessageCSMOpt{}
}

type CSMExchangeTimeoutOpt struct {
timeout time.Duration
}

func (o CSMExchangeTimeoutOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.CSMExchangeTimeout = o.timeout
}

func WithCSMExchangeTimeout(timeout time.Duration) CSMExchangeTimeoutOpt {
return CSMExchangeTimeoutOpt{
timeout: timeout,
}
}

// TLSOpt tls configuration option.
type TLSOpt struct {
tlsCfg *tls.Config
Expand Down
39 changes: 35 additions & 4 deletions tcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/blockwise"
Expand All @@ -30,19 +31,23 @@ func Dial(target string, opts ...Option) (*client.Conn, error) {
var conn net.Conn
var err error
if cfg.TLSCfg != nil {
conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg)
d := &tls.Dialer{
NetDialer: cfg.Dialer,
Config: cfg.TLSCfg,
}
conn, err = d.DialContext(cfg.Ctx, cfg.Net, target)
} else {
conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target)
}
if err != nil {
return nil, err
}
opts = append(opts, options.WithCloseSocket())
return Client(conn, opts...), nil
return Client(conn, opts...)
}

// Client creates client over tcp/tcp-tls connection.
func Client(conn net.Conn, opts ...Option) *client.Conn {
func Client(conn net.Conn, opts ...Option) (*client.Conn, error) {
cfg := client.DefaultConfig
for _, o := range opts {
o.TCPClientApply(&cfg)
Expand Down Expand Up @@ -100,12 +105,38 @@ func Client(conn net.Conn, opts ...Option) *client.Conn {
return cc.Context().Err() == nil
})

var csmExchangeDone chan struct{}
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})

cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
close(csmExchangeDone)
}
})
}

go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()

return cc
// if CSM messages are enabled, wait for the CSM messages to be exchanged
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
select {
case <-time.After(cfg.CSMExchangeTimeout):
err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
cfg.Errors(err)
cc.Close() // Close connection on timeout
return nil, err // or return cc with an error state
case <-csmExchangeDone:
// CSM exchange completed successfully
}
// Clear the handler after exchange is complete or timed out
cc.SetTCPSignalReceivedHandler(nil)
}

return cc, nil
}
1 change: 1 addition & 0 deletions tcp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ type Config struct {
DisablePeerTCPSignalMessageCSMs bool
CloseSocket bool
DisableTCPSignalMessageCSM bool
CSMExchangeTimeout time.Duration
}
40 changes: 35 additions & 5 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/plgd-dev/go-coap/v3/message"
Expand All @@ -27,6 +28,8 @@ type InactivityMonitor interface {
CheckInactivity(now time.Time, cc *Conn)
}

type TCPSignalReceivedHandler func(codes.Code)

type (
HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message)
ErrorFunc = func(error)
Expand All @@ -51,6 +54,8 @@ type Conn struct {
blockwiseSZX blockwise.SZX
peerMaxMessageSize atomic.Uint32
disablePeerTCPSignalMessageCSMs bool
tcpSignalReceivedHandler TCPSignalReceivedHandler
handlerMutex sync.RWMutex
peerBlockWiseTranferEnabled atomic.Bool

receivedMessageReader *client.ReceivedMessageReader[*Conn]
Expand Down Expand Up @@ -267,6 +272,12 @@ func (cc *Conn) Run() (err error) {
return cc.session.Run(cc)
}

func (cc *Conn) SetTCPSignalReceivedHandler(handler TCPSignalReceivedHandler) {
cc.handlerMutex.Lock()
defer cc.handlerMutex.Unlock()
cc.tcpSignalReceivedHandler = handler
}

// AddOnClose calls function on close connection event.
func (cc *Conn) AddOnClose(f EventFunc) {
cc.session.AddOnClose(f)
Expand Down Expand Up @@ -370,6 +381,14 @@ func (cc *Conn) sendPong(token message.Token) error {
return cc.Session().WriteMessage(req)
}

func (cc *Conn) handleTCPSignalReceived(code codes.Code) {
cc.handlerMutex.RLock()
defer cc.handlerMutex.RUnlock()
if cc.tcpSignalReceivedHandler != nil {
cc.tcpSignalReceivedHandler(code)
}
}

func (cc *Conn) handleSignals(r *pool.Message) bool {
switch r.Code() {
case codes.CSM:
Expand All @@ -382,6 +401,9 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if r.HasOption(message.TCPBlockWiseTransfer) {
cc.peerBlockWiseTranferEnabled.Store(true)
}

// signal CSM message is received.
cc.handleTCPSignalReceived(codes.CSM)
return true
case codes.Ping:
// if r.HasOption(message.TCPCustody) {
Expand All @@ -390,21 +412,29 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) {
cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err))
}

cc.handleTCPSignalReceived(codes.Ping)
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Pong)
return true
case codes.Release:
// if r.HasOption(message.TCPAlternativeAddress) {
// TODO
// }

cc.handleTCPSignalReceived(codes.Release)
return true
case codes.Abort:
// if r.HasOption(message.TCPBadCSMOption) {
// TODO
// }
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Abort)
return true
}
return false
Expand Down
78 changes: 78 additions & 0 deletions tcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/plgd-dev/go-coap/v3/options/config"
"github.com/plgd-dev/go-coap/v3/pkg/runner/periodic"
"github.com/plgd-dev/go-coap/v3/tcp/client"
"github.com/plgd-dev/go-coap/v3/tcp/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -839,3 +840,80 @@ func TestConnRequestMonitorDropRequest(t *testing.T) {
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestConnWithCSMExchangeTimeout(t *testing.T) {
type args struct {
clientOptions []Option
serverOptions []server.Option
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "client-server-no-csm",
args: args{
clientOptions: []Option{},
serverOptions: []server.Option{},
},
wantErr: false,
},
{
name: "client-server-csm-success",
args: args{
clientOptions: []Option{
options.WithCSMExchangeTimeout(time.Second * 3),
},
},
wantErr: false,
},
{
name: "client-server-csm-timeout",
args: args{
clientOptions: []Option{
options.WithCSMExchangeTimeout(time.Second * 3),
},
serverOptions: []server.Option{
options.WithDisableTCPSignalMessageCSM(),
},
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, err := coapNet.NewTCPListener("tcp", "")
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

s := NewServer(tt.args.serverOptions...)
defer s.Stop()
wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
assert.NoError(t, errS)
}()

cc, err := Dial(l.Addr().String(), tt.args.clientOptions...)
if tt.wantErr {
require.Nil(t, cc)
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, cc)
defer func() {
_ = cc.Close()
<-cc.Done()
}()
})
}
}
3 changes: 2 additions & 1 deletion tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ func TestServerKeepAliveMonitor(t *testing.T) {
assert.NoError(t, errS)
}()

cc, err := net.Dial("tcp", ld.Addr().String())
dialer := net.Dialer{}
cc, err := dialer.DialContext(context.Background(), "tcp", ld.Addr().String())
require.NoError(t, err)
defer func() {
_ = cc.Close()
Expand Down
Loading