Skip to content

Commit b5ed2cc

Browse files
committed
feat: adds context support
1 parent b2960c6 commit b5ed2cc

File tree

2 files changed

+84
-11
lines changed

2 files changed

+84
-11
lines changed

sctp.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ package sctp
1717

1818
import (
1919
"bytes"
20+
"context"
2021
"encoding/binary"
2122
"fmt"
23+
"log"
2224
"net"
2325
"strconv"
2426
"strings"
@@ -437,6 +439,7 @@ func SCTPBind(fd int, addr *SCTPAddr, flags int) error {
437439
type SCTPConn struct {
438440
_fd int32
439441
notificationHandler NotificationHandler
442+
ctx context.Context
440443
}
441444

442445
func (c *SCTPConn) fd() int {
@@ -451,6 +454,15 @@ func NewSCTPConn(fd int, handler NotificationHandler) *SCTPConn {
451454
return conn
452455
}
453456

457+
func dialSCTPConnect(sock int, raddr *SCTPAddr, notificationHandler NotificationHandler) (*SCTPConn, error) {
458+
_, err := SCTPConnect(sock, raddr)
459+
if err != nil {
460+
return nil, err
461+
}
462+
log.Println("4-CANCELLEDDDDDDDD")
463+
return NewSCTPConn(sock, notificationHandler), nil
464+
}
465+
454466
func (c *SCTPConn) Write(b []byte) (int, error) {
455467
return c.SCTPWrite(b, nil)
456468
}

sctp_linux.go

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
package sctp
2020

2121
import (
22+
"context"
2223
"io"
24+
"log"
2325
"net"
2426
"os"
2527
"runtime"
@@ -105,7 +107,24 @@ func (c *SCTPConn) SCTPWrite(b []byte, info *SndRcvInfo) (int, error) {
105107
hdr.SetLen(syscall.CmsgSpace(len(cmsgBuf)))
106108
cbuf = append(toBuf(hdr), cmsgBuf...)
107109
}
108-
return syscall.SendmsgN(c.fd(), b, cbuf, nil, syscall.MSG_DONTWAIT)
110+
111+
if c.ctx != nil {
112+
for {
113+
select {
114+
case <-c.ctx.Done():
115+
c.Close()
116+
return 0, syscall.EBADF
117+
default:
118+
n, err := syscall.SendmsgN(c.fd(), b, cbuf, nil, syscall.MSG_DONTWAIT)
119+
if err != nil {
120+
continue
121+
}
122+
return n, err
123+
}
124+
}
125+
}
126+
127+
return syscall.SendmsgN(c.fd(), b, cbuf, nil, 0)
109128
}
110129

111130
func parseSndRcvInfo(b []byte) (*SndRcvInfo, error) {
@@ -354,21 +373,67 @@ func DialSCTP(net string, laddr, raddr *SCTPAddr) (*SCTPConn, error) {
354373
return DialSCTPExt(net, laddr, raddr, InitMsg{NumOstreams: SCTP_MAX_STREAM})
355374
}
356375

376+
// DialSCTPWithContext - bind socket to laddr (if given) and connect to raddr with context
377+
func DialSCTPWithContext(net string, laddr, raddr *SCTPAddr, ctx context.Context) (*SCTPConn, error) {
378+
return dialSCTPExtConfigWithContext(net, laddr, raddr, InitMsg{NumOstreams: SCTP_MAX_STREAM}, nil, nil, ctx)
379+
}
380+
357381
// DialSCTPExt - same as DialSCTP but with given SCTP options
358382
func DialSCTPExt(network string, laddr, raddr *SCTPAddr, options InitMsg) (*SCTPConn, error) {
359383
return dialSCTPExtConfig(network, laddr, raddr, options, nil, nil)
360384
}
361385

362386
// dialSCTPExtConfig - same as DialSCTP but with given SCTP options and socket configuration
363387
func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error, notificationHandler NotificationHandler) (*SCTPConn, error) {
388+
sock, err := clientSocket(network, laddr, raddr, options, control)
389+
if err != nil {
390+
return nil, err
391+
}
392+
393+
return dialSCTPConnect(sock, raddr, notificationHandler)
394+
}
395+
396+
// dialSCTPExtConfig - same as DialSCTP but with given SCTP options and socket configuration
397+
func dialSCTPExtConfigWithContext(network string, laddr, raddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error, notificationHandler NotificationHandler, ctx context.Context) (*SCTPConn, error) {
398+
sock, err := clientSocket(network, laddr, raddr, options, control)
399+
if err != nil {
400+
return nil, err
401+
}
402+
403+
connChan := make(chan *SCTPConn, 1)
404+
errChan := make(chan error, 1)
405+
done := make(chan struct{})
406+
407+
go func() {
408+
connection, errDial := dialSCTPConnect(sock, raddr, notificationHandler)
409+
if errDial == nil {
410+
connection.ctx = ctx
411+
}
412+
connChan <- connection
413+
errChan <- errDial
414+
close(done)
415+
}()
416+
417+
select {
418+
case <-ctx.Done():
419+
log.Println("2-CANCELLEDDDDDDDD")
420+
closeSctpSocket(sock, 1*time.Second)
421+
log.Println("3-CANCELLEDDDDDDDD")
422+
return nil, ctx.Err()
423+
case <-done:
424+
return <-connChan, <-errChan
425+
}
426+
}
427+
428+
func clientSocket(network string, laddr, raddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error) (int, error) {
364429
af, ipv6only := favoriteAddrFamily(network, laddr, raddr, "dial")
365430
sock, err := syscall.Socket(
366431
af,
367432
syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC,
368433
syscall.IPPROTO_SCTP,
369434
)
370435
if err != nil {
371-
return nil, err
436+
return -1, err
372437
}
373438

374439
// close socket on error
@@ -378,7 +443,7 @@ func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg,
378443
}
379444
}()
380445
if err = setDefaultSockopts(sock, af, ipv6only); err != nil {
381-
return nil, err
446+
return -1, err
382447
}
383448
if control != nil {
384449
rc := rawConn{sockfd: sock}
@@ -387,12 +452,12 @@ func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg,
387452
localAddressString = laddr.String()
388453
}
389454
if err = control(network, localAddressString, rc); err != nil {
390-
return nil, err
455+
return -1, err
391456
}
392457
}
393458
err = setInitOpts(sock, options)
394459
if err != nil {
395-
return nil, err
460+
return -1, err
396461
}
397462
if laddr != nil {
398463
// If IP address and/or port was not provided so far, let's use the unspecified IPv4 or IPv6 address
@@ -405,12 +470,8 @@ func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg,
405470
}
406471
err = SCTPBind(sock, laddr, SCTP_BINDX_ADD_ADDR) // error EADDRINUSE "Address already in use" may occur if resource (source IP and Port) is occupied
407472
if err != nil {
408-
return nil, err
473+
return -1, err
409474
}
410475
}
411-
_, err = SCTPConnect(sock, raddr)
412-
if err != nil {
413-
return nil, err
414-
}
415-
return NewSCTPConn(sock, notificationHandler), nil
476+
return sock, nil
416477
}

0 commit comments

Comments
 (0)