Skip to content

Commit a109f8d

Browse files
committed
expose methods from underlying connection types
This sucks but I can't think of a better way to do this. We really do want to expose these features and doing so through type assertions is very go-like.
1 parent 8792ba0 commit a109f8d

File tree

2 files changed

+68
-24
lines changed

2 files changed

+68
-24
lines changed

net.go

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,64 @@ type Conn interface {
2828
RemoteMultiaddr() ma.Multiaddr
2929
}
3030

31-
// WrapNetConn wraps a net.Conn object with a Multiaddr
32-
// friendly Conn.
31+
type halfOpen interface {
32+
net.Conn
33+
CloseRead() error
34+
CloseWrite() error
35+
}
36+
37+
func wrap(nconn net.Conn, laddr, raddr ma.Multiaddr) Conn {
38+
endpts := maEndpoints{
39+
laddr: laddr,
40+
raddr: raddr,
41+
}
42+
// This sucks. However, it's the only way to reliably expose the
43+
// underlying methods. This way, users that need access to, e.g.,
44+
// CloseRead and CloseWrite, can do so via type assertions.
45+
switch nconn := nconn.(type) {
46+
case *net.TCPConn:
47+
return &struct {
48+
*net.TCPConn
49+
maEndpoints
50+
}{nconn, endpts}
51+
case *net.UDPConn:
52+
return &struct {
53+
*net.UDPConn
54+
maEndpoints
55+
}{nconn, endpts}
56+
case *net.IPConn:
57+
return &struct {
58+
*net.IPConn
59+
maEndpoints
60+
}{nconn, endpts}
61+
case *net.UnixConn:
62+
return &struct {
63+
*net.UnixConn
64+
maEndpoints
65+
}{nconn, endpts}
66+
case halfOpen:
67+
return &struct {
68+
halfOpen
69+
maEndpoints
70+
}{nconn, endpts}
71+
default:
72+
return &struct {
73+
net.Conn
74+
maEndpoints
75+
}{nconn, endpts}
76+
}
77+
}
78+
79+
// WrapNetConn wraps a net.Conn object with a Multiaddr friendly Conn.
80+
//
81+
// This function does it's best to avoid "hiding" methods exposed by the wrapped
82+
// type. Guarantees:
83+
//
84+
// * If the wrapped connection exposes the "half-open" closer methods
85+
// (CloseWrite, CloseRead), these will be available on the wrapped connection
86+
// via type assertions.
87+
// * If the wrapped connection is a UnixConn, IPConn, TCPConn, or UDPConn, all
88+
// methods on these wrapped connections will be available via type assertions.
3389
func WrapNetConn(nconn net.Conn) (Conn, error) {
3490
if nconn == nil {
3591
return nil, fmt.Errorf("failed to convert nconn.LocalAddr: nil")
@@ -45,30 +101,23 @@ func WrapNetConn(nconn net.Conn) (Conn, error) {
45101
return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
46102
}
47103

48-
return &maConn{
49-
Conn: nconn,
50-
laddr: laddr,
51-
raddr: raddr,
52-
}, nil
104+
return wrap(nconn, laddr, raddr), nil
53105
}
54106

55-
// maConn implements the Conn interface. It's a thin wrapper
56-
// around a net.Conn
57-
type maConn struct {
58-
net.Conn
107+
type maEndpoints struct {
59108
laddr ma.Multiaddr
60109
raddr ma.Multiaddr
61110
}
62111

63112
// LocalMultiaddr returns the local address associated with
64113
// this connection
65-
func (c *maConn) LocalMultiaddr() ma.Multiaddr {
114+
func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr {
66115
return c.laddr
67116
}
68117

69118
// RemoteMultiaddr returns the remote address associated with
70119
// this connection
71-
func (c *maConn) RemoteMultiaddr() ma.Multiaddr {
120+
func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
72121
return c.raddr
73122
}
74123

@@ -135,12 +184,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
135184
return nil, err
136185
}
137186
}
138-
139-
return &maConn{
140-
Conn: nconn,
141-
laddr: local,
142-
raddr: remote,
143-
}, nil
187+
return wrap(nconn, local, remote), nil
144188
}
145189

146190
// Dial connects to a remote address. It uses an underlying net.Conn,
@@ -204,11 +248,7 @@ func (l *maListener) Accept() (Conn, error) {
204248
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err)
205249
}
206250

207-
return &maConn{
208-
Conn: nconn,
209-
laddr: l.laddr,
210-
raddr: raddr,
211-
}, nil
251+
return wrap(nconn, l.laddr, raddr), nil
212252
}
213253

214254
// Multiaddr returns the listener's (local) Multiaddr.

net_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,12 +407,14 @@ func TestWrapNetConn(t *testing.T) {
407407
defer wg.Done()
408408
cB, err := listener.Accept()
409409
checkErr(err, "failed to accept")
410+
_ = cB.(halfOpen)
410411
cB.Close()
411412
}()
412413

413414
cA, err := net.Dial("tcp", listener.Addr().String())
414415
checkErr(err, "failed to dial")
415416
defer cA.Close()
417+
_ = cA.(halfOpen)
416418

417419
lmaddr, err := FromNetAddr(cA.LocalAddr())
418420
checkErr(err, "failed to get local addr")
@@ -422,6 +424,8 @@ func TestWrapNetConn(t *testing.T) {
422424
mcA, err := WrapNetConn(cA)
423425
checkErr(err, "failed to wrap conn")
424426

427+
_ = mcA.(halfOpen)
428+
425429
if mcA.LocalAddr().String() != cA.LocalAddr().String() {
426430
t.Error("wrapped conn local addr differs")
427431
}

0 commit comments

Comments
 (0)