Skip to content
This repository was archived by the owner on Feb 7, 2024. It is now read-only.

Commit 98dafb5

Browse files
authored
Merge pull request #202 from gfanton/dev/uds-support
Add Unix Domain Socket support to Shell client
2 parents 5d34b6b + 235270b commit 98dafb5

File tree

4 files changed

+88
-9
lines changed

4 files changed

+88
-9
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ os:
44
language: go
55

66
go:
7-
- 1.11.x
7+
- 1.13.x
88

99
services:
1010
- docker

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
go 1.13
2+
13
module github.com/ipfs/go-ipfs-api
24

35
require (

shell.go

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"io/ioutil"
12+
"net"
1213
gohttp "net/http"
1314
"os"
1415
"path"
@@ -72,20 +73,51 @@ func NewShell(url string) *Shell {
7273
return NewShellWithClient(url, c)
7374
}
7475

75-
func NewShellWithClient(url string, c *gohttp.Client) *Shell {
76-
if a, err := ma.NewMultiaddr(url); err == nil {
77-
_, host, err := manet.DialArgs(a)
78-
if err == nil {
79-
url = host
80-
}
81-
}
76+
func NewShellWithClient(url string, client *gohttp.Client) *Shell {
8277
var sh Shell
78+
8379
sh.url = url
84-
sh.httpcli = *c
80+
sh.httpcli = *client
8581
// We don't support redirects.
8682
sh.httpcli.CheckRedirect = func(_ *gohttp.Request, _ []*gohttp.Request) error {
8783
return fmt.Errorf("unexpected redirect")
8884
}
85+
86+
maddr, err := ma.NewMultiaddr(url)
87+
if err != nil {
88+
return &sh
89+
}
90+
91+
network, host, err := manet.DialArgs(maddr)
92+
if err != nil {
93+
return &sh
94+
}
95+
96+
if network == "unix" {
97+
sh.url = network
98+
99+
var tptCopy *gohttp.Transport
100+
if tpt, ok := sh.httpcli.Transport.(*gohttp.Transport); ok && tpt.DialContext == nil {
101+
tptCopy = tpt.Clone()
102+
} else if sh.httpcli.Transport == nil {
103+
tptCopy = &gohttp.Transport{
104+
Proxy: gohttp.ProxyFromEnvironment,
105+
DisableKeepAlives: true,
106+
}
107+
} else {
108+
// custom Transport or custom Dialer, we are done here
109+
return &sh
110+
}
111+
112+
tptCopy.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
113+
return net.Dial("unix", host)
114+
}
115+
116+
sh.httpcli.Transport = tptCopy
117+
} else {
118+
sh.url = host
119+
}
120+
89121
return &sh
90122
}
91123

shell_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
package shell
22

33
import (
4+
"bufio"
45
"bytes"
56
"context"
67
"crypto/md5"
78
"fmt"
89
"io"
10+
"io/ioutil"
911
"math/rand"
12+
"net"
13+
"net/http"
14+
"os"
15+
"path/filepath"
1016
"sort"
1117
"strings"
1218
"testing"
@@ -361,6 +367,45 @@ func TestSwarmPeers(t *testing.T) {
361367
is.Nil(err)
362368
}
363369

370+
// TestNewShellWithUnixSocket only check that http client is well configured to
371+
// perform http request on unix socket address
372+
func TestNewShellWithUnixSocket(t *testing.T) {
373+
is := is.New(t)
374+
375+
// setup uds temporary dir
376+
path, err := ioutil.TempDir("", "uds-test")
377+
is.Nil(err)
378+
379+
defer os.RemoveAll(path)
380+
381+
// listen on sock path
382+
sockpath := filepath.Join(path, "sock")
383+
lsock, err := net.Listen("unix", sockpath)
384+
is.Nil(err)
385+
386+
defer lsock.Close()
387+
388+
// handle simple `hello` route
389+
mux := http.NewServeMux()
390+
mux.HandleFunc("/api/v0/hello", func(w http.ResponseWriter, _ *http.Request) {
391+
fmt.Fprint(w, "Hello World\n")
392+
})
393+
394+
go http.Serve(lsock, mux)
395+
396+
// create shell with "/unix/<sockpath>" multiaddr
397+
shell := NewShell("/unix/" + sockpath)
398+
res, err := shell.Request("hello").Send(context.Background())
399+
is.Nil(err)
400+
401+
defer res.Output.Close()
402+
403+
// read hello world from body
404+
str, err := bufio.NewReader(res.Output).ReadString('\n')
405+
is.Nil(err)
406+
is.Equal(str, "Hello World\n")
407+
}
408+
364409
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
365410
const (
366411
letterIdxBits = 6 // 6 bits to represent a letter index

0 commit comments

Comments
 (0)