Skip to content

Commit 1879060

Browse files
authored
Merge pull request #50 from multiformats/feat/unix-sockets
Add support for unix sockets
2 parents 0959afa + a86fb2b commit 1879060

File tree

4 files changed

+94
-7
lines changed

4 files changed

+94
-7
lines changed

convert.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package manet
33
import (
44
"fmt"
55
"net"
6+
"path/filepath"
67

78
ma "github.com/multiformats/go-multiaddr"
89
madns "github.com/multiformats/go-multiaddr-dns"
@@ -61,6 +62,8 @@ func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) {
6162
return net.ResolveUDPAddr(network, host)
6263
case "ip", "ip4", "ip6":
6364
return net.ResolveIPAddr(network, host)
65+
case "unix":
66+
return net.ResolveUnixAddr(network, host)
6467
}
6568

6669
return nil, fmt.Errorf("network not supported: %s", network)
@@ -96,7 +99,8 @@ func FromIP(ip net.IP) (ma.Multiaddr, error) {
9699

97100
// DialArgs is a convenience function that returns network and address as
98101
// expected by net.Dial. See https://godoc.org/net#Dial for an overview of
99-
// possible return values (we do not support the unix* ones yet).
102+
// possible return values (we do not support the unixpacket ones yet). Unix
103+
// addresses do not, at present, compose.
100104
func DialArgs(m ma.Multiaddr) (string, string, error) {
101105
var (
102106
zone, network, ip, port string
@@ -137,6 +141,10 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
137141
hostname = true
138142
ip = c.Value()
139143
return true
144+
case ma.P_UNIX:
145+
network = "unix"
146+
ip = c.Value()
147+
return false
140148
}
141149
case "ip4":
142150
switch c.Protocol().Code {
@@ -184,6 +192,8 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
184192
return network, ip + ":" + port, nil
185193
}
186194
return network, "[" + ip + "]" + ":" + port, nil
195+
case "unix":
196+
return network, ip, nil
187197
default:
188198
return "", "", fmt.Errorf("%s is not a 'thin waist' address", m)
189199
}
@@ -248,3 +258,12 @@ func parseIPPlusNetAddr(a net.Addr) (ma.Multiaddr, error) {
248258
}
249259
return FromIP(ac.IP)
250260
}
261+
262+
func parseUnixNetAddr(a net.Addr) (ma.Multiaddr, error) {
263+
ac, ok := a.(*net.UnixAddr)
264+
if !ok {
265+
return nil, errIncorrectNetAddr
266+
}
267+
cleaned := filepath.Clean(ac.Name)
268+
return ma.NewComponent("unix", cleaned)
269+
}

net.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
167167
// ok, Dial!
168168
var nconn net.Conn
169169
switch rnet {
170-
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
170+
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix":
171171
nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr)
172172
if err != nil {
173173
return nil, err
@@ -178,7 +178,9 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
178178

179179
// get local address (pre-specified or assigned within net.Conn)
180180
local := d.LocalAddr
181-
if local == nil {
181+
// This block helps us avoid parsing addresses in transports (such as unix
182+
// sockets) that don't have local addresses when dialing out.
183+
if local == nil && nconn.LocalAddr().String() != "" {
182184
local, err = FromNetAddr(nconn.LocalAddr())
183185
if err != nil {
184186
return nil, err
@@ -243,9 +245,14 @@ func (l *maListener) Accept() (Conn, error) {
243245
return nil, err
244246
}
245247

246-
raddr, err := FromNetAddr(nconn.RemoteAddr())
247-
if err != nil {
248-
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err)
248+
var raddr ma.Multiaddr
249+
// This block protects us in transports (i.e. unix sockets) that don't have
250+
// remote addresses for inbound connections.
251+
if nconn.RemoteAddr().String() != "" {
252+
raddr, err = FromNetAddr(nconn.RemoteAddr())
253+
if err != nil {
254+
return nil, fmt.Errorf("failed to convert conn.RemoteAddr: %s", err)
255+
}
249256
}
250257

251258
return wrap(nconn, l.laddr, raddr), nil

net_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ package manet
33
import (
44
"bytes"
55
"fmt"
6+
"io/ioutil"
67
"net"
8+
"os"
9+
"path/filepath"
710
"sync"
811
"testing"
12+
"time"
913

1014
ma "github.com/multiformats/go-multiaddr"
1115
)
@@ -75,6 +79,62 @@ func TestDial(t *testing.T) {
7579
wg.Wait()
7680
}
7781

82+
func TestUnixSockets(t *testing.T) {
83+
dir, err := ioutil.TempDir(os.TempDir(), "manettest")
84+
if err != nil {
85+
t.Fatal(err)
86+
}
87+
path := filepath.Join(dir, "listen.sock")
88+
maddr := newMultiaddr(t, "/unix/"+path)
89+
90+
listener, err := Listen(maddr)
91+
if err != nil {
92+
t.Fatal(err)
93+
}
94+
95+
payload := []byte("hello")
96+
97+
// listen
98+
done := make(chan struct{}, 1)
99+
go func() {
100+
conn, err := listener.Accept()
101+
if err != nil {
102+
t.Fatal(err)
103+
}
104+
defer conn.Close()
105+
buf := make([]byte, 1024)
106+
n, err := conn.Read(buf)
107+
if err != nil {
108+
t.Fatal(err)
109+
}
110+
if n != len(payload) {
111+
t.Fatal("failed to read appropriate number of bytes")
112+
}
113+
if !bytes.Equal(buf[0:n], payload) {
114+
t.Fatal("payload did not match")
115+
}
116+
done <- struct{}{}
117+
}()
118+
119+
// dial
120+
conn, err := Dial(maddr)
121+
if err != nil {
122+
t.Fatal(err)
123+
}
124+
n, err := conn.Write(payload)
125+
if err != nil {
126+
t.Fatal(err)
127+
}
128+
if n != len(payload) {
129+
t.Fatal("failed to write appropriate number of bytes")
130+
}
131+
select {
132+
case <-done:
133+
case <-time.After(1 * time.Second):
134+
t.Fatal("timed out waiting for read")
135+
}
136+
}
137+
78138
func TestListen(t *testing.T) {
79139

80140
maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322")

registry.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ func init() {
2121
defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6")
2222
defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6")
2323
defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net")
24+
defaultCodecs.RegisterFromNetAddr(parseUnixNetAddr, "unix")
2425

25-
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4")
26+
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4", "unix")
2627
}
2728

2829
// CodecMap holds a map of NetCodecs indexed by their Protocol ID

0 commit comments

Comments
 (0)