Skip to content

Commit 2b804ab

Browse files
fraenkelseankhliao
authored andcommitted
net: context aware Dialer.Dial functions
Add context aware dial functions for TCP, UDP, IP and Unix networks. Fixes golang#49097 Updates golang#59897 Change-Id: I7523452e8e463a587a852e0555cec822d8dcb3dd Reviewed-on: https://go-review.googlesource.com/c/go/+/490975 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Dmitri Shuralyov <[email protected]> Reviewed-by: David Chase <[email protected]> Reviewed-by: Sean Liao <[email protected]>
1 parent 6abfe7b commit 2b804ab

File tree

8 files changed

+234
-29
lines changed

8 files changed

+234
-29
lines changed

api/next/49097.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
pkg net, method (*Dialer) DialIP(context.Context, string, netip.Addr, netip.Addr) (*IPConn, error) #49097
2+
pkg net, method (*Dialer) DialTCP(context.Context, string, netip.AddrPort, netip.AddrPort) (*TCPConn, error) #49097
3+
pkg net, method (*Dialer) DialUDP(context.Context, string, netip.AddrPort, netip.AddrPort) (*UDPConn, error) #49097
4+
pkg net, method (*Dialer) DialUnix(context.Context, string, *UnixAddr, *UnixAddr) (*UnixConn, error) #49097
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added context aware dial functions for TCP, UDP, IP and Unix networks.

src/net/dial.go

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"internal/bytealg"
1010
"internal/godebug"
1111
"internal/nettrace"
12+
"net/netip"
1213
"syscall"
1314
"time"
1415
)
@@ -523,30 +524,8 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
523524
// See func [Dial] for a description of the network and address
524525
// parameters.
525526
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
526-
if ctx == nil {
527-
panic("nil context")
528-
}
529-
deadline := d.deadline(ctx, time.Now())
530-
if !deadline.IsZero() {
531-
testHookStepTime()
532-
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
533-
subCtx, cancel := context.WithDeadline(ctx, deadline)
534-
defer cancel()
535-
ctx = subCtx
536-
}
537-
}
538-
if oldCancel := d.Cancel; oldCancel != nil {
539-
subCtx, cancel := context.WithCancel(ctx)
540-
defer cancel()
541-
go func() {
542-
select {
543-
case <-oldCancel:
544-
cancel()
545-
case <-subCtx.Done():
546-
}
547-
}()
548-
ctx = subCtx
549-
}
527+
ctx, cancel := d.dialCtx(ctx)
528+
defer cancel()
550529

551530
// Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
552531
resolveCtx := ctx
@@ -578,6 +557,97 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
578557
return sd.dialParallel(ctx, primaries, fallbacks)
579558
}
580559

560+
func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
561+
if ctx == nil {
562+
panic("nil context")
563+
}
564+
deadline := d.deadline(ctx, time.Now())
565+
var cancel1, cancel2 context.CancelFunc
566+
if !deadline.IsZero() {
567+
testHookStepTime()
568+
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
569+
var subCtx context.Context
570+
subCtx, cancel1 = context.WithDeadline(ctx, deadline)
571+
ctx = subCtx
572+
}
573+
}
574+
if oldCancel := d.Cancel; oldCancel != nil {
575+
subCtx, cancel2 := context.WithCancel(ctx)
576+
go func() {
577+
select {
578+
case <-oldCancel:
579+
cancel2()
580+
case <-subCtx.Done():
581+
}
582+
}()
583+
ctx = subCtx
584+
}
585+
return ctx, func() {
586+
if cancel1 != nil {
587+
cancel1()
588+
}
589+
if cancel2 != nil {
590+
cancel2()
591+
}
592+
}
593+
}
594+
595+
// DialTCP acts like Dial for TCP networks using the provided context.
596+
//
597+
// The provided Context must be non-nil. If the context expires before
598+
// the connection is complete, an error is returned. Once successfully
599+
// connected, any expiration of the context will not affect the
600+
// connection.
601+
//
602+
// The network must be a TCP network name; see func Dial for details.
603+
func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
604+
ctx, cancel := d.dialCtx(ctx)
605+
defer cancel()
606+
return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
607+
}
608+
609+
// DialUDP acts like Dial for UDP networks using the provided context.
610+
//
611+
// The provided Context must be non-nil. If the context expires before
612+
// the connection is complete, an error is returned. Once successfully
613+
// connected, any expiration of the context will not affect the
614+
// connection.
615+
//
616+
// The network must be a UDP network name; see func Dial for details.
617+
func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
618+
ctx, cancel := d.dialCtx(ctx)
619+
defer cancel()
620+
return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
621+
}
622+
623+
// DialIP acts like Dial for IP networks using the provided context.
624+
//
625+
// The provided Context must be non-nil. If the context expires before
626+
// the connection is complete, an error is returned. Once successfully
627+
// connected, any expiration of the context will not affect the
628+
// connection.
629+
//
630+
// The network must be an IP network name; see func Dial for details.
631+
func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
632+
ctx, cancel := d.dialCtx(ctx)
633+
defer cancel()
634+
return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
635+
}
636+
637+
// DialUnix acts like Dial for Unix networks using the provided context.
638+
//
639+
// The provided Context must be non-nil. If the context expires before
640+
// the connection is complete, an error is returned. Once successfully
641+
// connected, any expiration of the context will not affect the
642+
// connection.
643+
//
644+
// The network must be a Unix network name; see func Dial for details.
645+
func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
646+
ctx, cancel := d.dialCtx(ctx)
647+
defer cancel()
648+
return dialUnix(ctx, d, network, laddr, raddr)
649+
}
650+
581651
// dialParallel races two copies of dialSerial, giving the first a
582652
// head start. It returns the first established connection and
583653
// closes the others. Otherwise it returns an error from the first

src/net/dial_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"internal/testenv"
1313
"io"
14+
"net/netip"
1415
"os"
1516
"runtime"
1617
"strings"
@@ -1064,6 +1065,99 @@ func TestDialerControlContext(t *testing.T) {
10641065
})
10651066
}
10661067

1068+
func TestDialContext(t *testing.T) {
1069+
switch runtime.GOOS {
1070+
case "plan9":
1071+
t.Skipf("not supported on %s", runtime.GOOS)
1072+
case "js", "wasip1":
1073+
t.Skipf("skipping: fake net does not support Dialer.ControlContext")
1074+
}
1075+
1076+
t.Run("StreamDial", func(t *testing.T) {
1077+
var err error
1078+
for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
1079+
if !testableNetwork(network) {
1080+
continue
1081+
}
1082+
ln := newLocalListener(t, network)
1083+
defer ln.Close()
1084+
var id int
1085+
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
1086+
id = ctx.Value("id").(int)
1087+
return controlOnConnSetup(network, address, c)
1088+
}}
1089+
var c Conn
1090+
switch network {
1091+
case "tcp", "tcp4", "tcp6":
1092+
raddr, err := netip.ParseAddrPort(ln.Addr().String())
1093+
if err != nil {
1094+
t.Error(err)
1095+
continue
1096+
}
1097+
c, err = d.DialTCP(context.WithValue(context.Background(), "id", i+1), network, (*TCPAddr)(nil).AddrPort(), raddr)
1098+
case "unix", "unixpacket":
1099+
raddr, err := ResolveUnixAddr(network, ln.Addr().String())
1100+
if err != nil {
1101+
t.Error(err)
1102+
continue
1103+
}
1104+
c, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
1105+
}
1106+
if err != nil {
1107+
t.Error(err)
1108+
continue
1109+
}
1110+
if id != i+1 {
1111+
t.Errorf("%s: got id %d, want %d", network, id, i+1)
1112+
}
1113+
c.Close()
1114+
}
1115+
})
1116+
t.Run("PacketDial", func(t *testing.T) {
1117+
var err error
1118+
for i, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
1119+
if !testableNetwork(network) {
1120+
continue
1121+
}
1122+
c1 := newLocalPacketListener(t, network)
1123+
if network == "unixgram" {
1124+
defer os.Remove(c1.LocalAddr().String())
1125+
}
1126+
defer c1.Close()
1127+
var id int
1128+
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
1129+
id = ctx.Value("id").(int)
1130+
return controlOnConnSetup(network, address, c)
1131+
}}
1132+
var c2 Conn
1133+
switch network {
1134+
case "udp", "udp4", "udp6":
1135+
raddr, err := netip.ParseAddrPort(c1.LocalAddr().String())
1136+
if err != nil {
1137+
t.Error(err)
1138+
continue
1139+
}
1140+
c2, err = d.DialUDP(context.WithValue(context.Background(), "id", i+1), network, (*UDPAddr)(nil).AddrPort(), raddr)
1141+
case "unixgram":
1142+
raddr, err := ResolveUnixAddr(network, c1.LocalAddr().String())
1143+
if err != nil {
1144+
t.Error(err)
1145+
continue
1146+
}
1147+
c2, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
1148+
}
1149+
if err != nil {
1150+
t.Error(err)
1151+
continue
1152+
}
1153+
if id != i+1 {
1154+
t.Errorf("%s: got id %d, want %d", network, id, i+1)
1155+
}
1156+
c2.Close()
1157+
}
1158+
})
1159+
}
1160+
10671161
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
10681162
// except on non-Linux, non-mobile builders it permits the test to
10691163
// run in -short mode.

src/net/iprawsock.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package net
66

77
import (
88
"context"
9+
"net/netip"
910
"syscall"
1011
)
1112

@@ -24,6 +25,13 @@ import (
2425
// BUG(mikio): On JS and Plan 9, methods and functions related
2526
// to IPConn are not implemented.
2627

28+
func ipAddrFromAddr(addr netip.Addr) *IPAddr {
29+
return &IPAddr{
30+
IP: addr.AsSlice(),
31+
Zone: addr.Zone(),
32+
}
33+
}
34+
2735
// IPAddr represents the address of an IP end point.
2836
type IPAddr struct {
2937
IP IP
@@ -206,11 +214,18 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
206214
// If the IP field of raddr is nil or an unspecified IP address, the
207215
// local system is assumed.
208216
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
217+
return dialIP(context.Background(), nil, network, laddr, raddr)
218+
}
219+
220+
func dialIP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *IPAddr) (*IPConn, error) {
209221
if raddr == nil {
210222
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
211223
}
212224
sd := &sysDialer{network: network, address: raddr.String()}
213-
c, err := sd.dialIP(context.Background(), laddr, raddr)
225+
if dialer != nil {
226+
sd.Dialer = *dialer
227+
}
228+
c, err := sd.dialIP(ctx, laddr, raddr)
214229
if err != nil {
215230
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
216231
}

src/net/tcpsock.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ func newTCPConn(fd *netFD, keepAliveIdle time.Duration, keepAliveCfg KeepAliveCo
315315
// If the IP field of raddr is nil or an unspecified IP address, the
316316
// local system is assumed.
317317
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
318+
return dialTCP(context.Background(), nil, network, laddr, raddr)
319+
}
320+
321+
func dialTCP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
318322
switch network {
319323
case "tcp", "tcp4", "tcp6":
320324
default:
@@ -328,10 +332,13 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
328332
c *TCPConn
329333
err error
330334
)
335+
if dialer != nil {
336+
sd.Dialer = *dialer
337+
}
331338
if sd.MultipathTCP() {
332-
c, err = sd.dialMPTCP(context.Background(), laddr, raddr)
339+
c, err = sd.dialMPTCP(ctx, laddr, raddr)
333340
} else {
334-
c, err = sd.dialTCP(context.Background(), laddr, raddr)
341+
c, err = sd.dialTCP(ctx, laddr, raddr)
335342
}
336343
if err != nil {
337344
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}

src/net/udpsock.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
285285
// If the IP field of raddr is nil or an unspecified IP address, the
286286
// local system is assumed.
287287
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
288+
return dialUDP(context.Background(), nil, network, laddr, raddr)
289+
}
290+
291+
func dialUDP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
288292
switch network {
289293
case "udp", "udp4", "udp6":
290294
default:
@@ -294,7 +298,10 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
294298
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
295299
}
296300
sd := &sysDialer{network: network, address: raddr.String()}
297-
c, err := sd.dialUDP(context.Background(), laddr, raddr)
301+
if dialer != nil {
302+
sd.Dialer = *dialer
303+
}
304+
c, err := sd.dialUDP(ctx, laddr, raddr)
298305
if err != nil {
299306
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
300307
}

src/net/unixsock.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,20 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
201201
// If laddr is non-nil, it is used as the local address for the
202202
// connection.
203203
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
204+
return dialUnix(context.Background(), nil, network, laddr, raddr)
205+
}
206+
207+
func dialUnix(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
204208
switch network {
205209
case "unix", "unixgram", "unixpacket":
206210
default:
207211
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
208212
}
209213
sd := &sysDialer{network: network, address: raddr.String()}
210-
c, err := sd.dialUnix(context.Background(), laddr, raddr)
214+
if dialer != nil {
215+
sd.Dialer = *dialer
216+
}
217+
c, err := sd.dialUnix(ctx, laddr, raddr)
211218
if err != nil {
212219
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
213220
}

0 commit comments

Comments
 (0)