@@ -28,8 +28,64 @@ type Conn interface {
28
28
RemoteMultiaddr () ma.Multiaddr
29
29
}
30
30
31
- // WrapNetConn wraps a net.Conn object with a Multiaddr
32
- // friendly Conn.
31
+ type halfOpen interface {
32
+ net.Conn
33
+ CloseRead () error
34
+ CloseWrite () error
35
+ }
36
+
37
+ func wrap (nconn net.Conn , laddr , raddr ma.Multiaddr ) Conn {
38
+ endpts := maEndpoints {
39
+ laddr : laddr ,
40
+ raddr : raddr ,
41
+ }
42
+ // This sucks. However, it's the only way to reliably expose the
43
+ // underlying methods. This way, users that need access to, e.g.,
44
+ // CloseRead and CloseWrite, can do so via type assertions.
45
+ switch nconn := nconn .(type ) {
46
+ case * net.TCPConn :
47
+ return & struct {
48
+ * net.TCPConn
49
+ maEndpoints
50
+ }{nconn , endpts }
51
+ case * net.UDPConn :
52
+ return & struct {
53
+ * net.UDPConn
54
+ maEndpoints
55
+ }{nconn , endpts }
56
+ case * net.IPConn :
57
+ return & struct {
58
+ * net.IPConn
59
+ maEndpoints
60
+ }{nconn , endpts }
61
+ case * net.UnixConn :
62
+ return & struct {
63
+ * net.UnixConn
64
+ maEndpoints
65
+ }{nconn , endpts }
66
+ case halfOpen :
67
+ return & struct {
68
+ halfOpen
69
+ maEndpoints
70
+ }{nconn , endpts }
71
+ default :
72
+ return & struct {
73
+ net.Conn
74
+ maEndpoints
75
+ }{nconn , endpts }
76
+ }
77
+ }
78
+
79
+ // WrapNetConn wraps a net.Conn object with a Multiaddr friendly Conn.
80
+ //
81
+ // This function does it's best to avoid "hiding" methods exposed by the wrapped
82
+ // type. Guarantees:
83
+ //
84
+ // * If the wrapped connection exposes the "half-open" closer methods
85
+ // (CloseWrite, CloseRead), these will be available on the wrapped connection
86
+ // via type assertions.
87
+ // * If the wrapped connection is a UnixConn, IPConn, TCPConn, or UDPConn, all
88
+ // methods on these wrapped connections will be available via type assertions.
33
89
func WrapNetConn (nconn net.Conn ) (Conn , error ) {
34
90
if nconn == nil {
35
91
return nil , fmt .Errorf ("failed to convert nconn.LocalAddr: nil" )
@@ -45,30 +101,23 @@ func WrapNetConn(nconn net.Conn) (Conn, error) {
45
101
return nil , fmt .Errorf ("failed to convert nconn.RemoteAddr: %s" , err )
46
102
}
47
103
48
- return & maConn {
49
- Conn : nconn ,
50
- laddr : laddr ,
51
- raddr : raddr ,
52
- }, nil
104
+ return wrap (nconn , laddr , raddr ), nil
53
105
}
54
106
55
- // maConn implements the Conn interface. It's a thin wrapper
56
- // around a net.Conn
57
- type maConn struct {
58
- net.Conn
107
+ type maEndpoints struct {
59
108
laddr ma.Multiaddr
60
109
raddr ma.Multiaddr
61
110
}
62
111
63
112
// LocalMultiaddr returns the local address associated with
64
113
// this connection
65
- func (c * maConn ) LocalMultiaddr () ma.Multiaddr {
114
+ func (c * maEndpoints ) LocalMultiaddr () ma.Multiaddr {
66
115
return c .laddr
67
116
}
68
117
69
118
// RemoteMultiaddr returns the remote address associated with
70
119
// this connection
71
- func (c * maConn ) RemoteMultiaddr () ma.Multiaddr {
120
+ func (c * maEndpoints ) RemoteMultiaddr () ma.Multiaddr {
72
121
return c .raddr
73
122
}
74
123
@@ -135,12 +184,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
135
184
return nil , err
136
185
}
137
186
}
138
-
139
- return & maConn {
140
- Conn : nconn ,
141
- laddr : local ,
142
- raddr : remote ,
143
- }, nil
187
+ return wrap (nconn , local , remote ), nil
144
188
}
145
189
146
190
// Dial connects to a remote address. It uses an underlying net.Conn,
@@ -204,11 +248,7 @@ func (l *maListener) Accept() (Conn, error) {
204
248
return nil , fmt .Errorf ("failed to convert connn.RemoteAddr: %s" , err )
205
249
}
206
250
207
- return & maConn {
208
- Conn : nconn ,
209
- laddr : l .laddr ,
210
- raddr : raddr ,
211
- }, nil
251
+ return wrap (nconn , l .laddr , raddr ), nil
212
252
}
213
253
214
254
// Multiaddr returns the listener's (local) Multiaddr.
0 commit comments