Skip to content
This repository was archived by the owner on Jan 21, 2020. It is now read-only.

Commit a8b26d3

Browse files
author
David Chung
authored
support headers in template fetch (#444)
Signed-off-by: David Chung <[email protected]>
1 parent dd74ee6 commit a8b26d3

File tree

5 files changed

+146
-33
lines changed

5 files changed

+146
-33
lines changed

pkg/rpc/server/info_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestFetchAPIInfoFromPlugin(t *testing.T) {
3434
server, err := StartPluginAtPath(socketPath, rpc_instance.PluginServer(&testing_instance.Plugin{}))
3535
require.NoError(t, err)
3636

37-
buff, err := template.Fetch(url, template.Options{SocketDir: dir})
37+
buff, err := template.Fetch(url, template.Options{SocketDir: dir}, nil)
3838
require.NoError(t, err)
3939

4040
decoded, err := template.FromJSON(buff)
@@ -45,7 +45,7 @@ func TestFetchAPIInfoFromPlugin(t *testing.T) {
4545
require.Equal(t, "Instance", result)
4646

4747
url = "unix://" + host + "/info/functions.json"
48-
buff, err = template.Fetch(url, template.Options{SocketDir: dir})
48+
buff, err = template.Fetch(url, template.Options{SocketDir: dir}, nil)
4949
require.NoError(t, err)
5050

5151
server.Stop()
@@ -91,7 +91,7 @@ func TestFetchFunctionsFromPlugin(t *testing.T) {
9191
server, err := StartPluginAtPath(socketPath, rpc_flavor.PluginServer(&exporter{&testing_flavor.Plugin{}}))
9292
require.NoError(t, err)
9393

94-
buff, err := template.Fetch(url, template.Options{SocketDir: dir})
94+
buff, err := template.Fetch(url, template.Options{SocketDir: dir}, nil)
9595
require.NoError(t, err)
9696

9797
decoded, err := template.FromJSON(buff)

pkg/template/fetch.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
// Fetch fetchs content from the given URL string. Supported schemes are http:// https:// file:// unix://
14-
func Fetch(s string, opt Options) ([]byte, error) {
14+
func Fetch(s string, opt Options, customize func(*http.Request)) ([]byte, error) {
1515
u, err := url.Parse(s)
1616
if err != nil {
1717
return nil, err
@@ -21,12 +21,7 @@ func Fetch(s string, opt Options) ([]byte, error) {
2121
return ioutil.ReadFile(u.Path)
2222

2323
case "http", "https":
24-
resp, err := http.Get(u.String())
25-
if err != nil {
26-
return nil, err
27-
}
28-
defer resp.Body.Close()
29-
return ioutil.ReadAll(resp.Body)
24+
return doHTTPGet(u, customize, &http.Client{})
3025

3126
case "unix":
3227
// unix: will look for a socket that matches the host name at a
@@ -36,17 +31,30 @@ func Fetch(s string, opt Options) ([]byte, error) {
3631
return nil, err
3732
}
3833
u.Scheme = "http"
39-
resp, err := c.Get(u.String())
40-
if err != nil {
41-
return nil, err
42-
}
43-
defer resp.Body.Close()
44-
return ioutil.ReadAll(resp.Body)
34+
return doHTTPGet(u, customize, c)
4535
}
4636

4737
return nil, fmt.Errorf("unsupported url:%s", s)
4838
}
4939

40+
func doHTTPGet(u *url.URL, customize func(*http.Request), client *http.Client) ([]byte, error) {
41+
req, err := http.NewRequest("GET", u.String(), nil)
42+
if err != nil {
43+
return nil, err
44+
}
45+
46+
if customize != nil {
47+
customize(req)
48+
}
49+
50+
resp, err := client.Do(req)
51+
if err != nil {
52+
return nil, err
53+
}
54+
defer resp.Body.Close()
55+
return ioutil.ReadAll(resp.Body)
56+
}
57+
5058
func socketClient(u *url.URL, socketDir string) (*http.Client, error) {
5159
socketPath := filepath.Join(socketDir, u.Host)
5260
if f, err := os.Stat(socketPath); err != nil {

pkg/template/funcs.go

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/json"
66
"fmt"
7+
"net/http"
78
"reflect"
89
"strings"
910
"time"
@@ -134,8 +135,45 @@ func IndexOf(srch interface{}, array interface{}, strictOptional ...bool) int {
134135
return -1
135136
}
136137

138+
// given optional args in a template function call, extra headers and the context
139+
func headersAndContext(opt ...interface{}) (headers map[string][]string, context interface{}) {
140+
if len(opt) == 0 {
141+
return
142+
}
143+
// scan through all the args and if it's a string of the form x=y, then use as header
144+
// the element that doesn't follow the form is the context
145+
headers = map[string][]string{}
146+
for _, v := range opt {
147+
if vv, is := v.(string); is && strings.Index(vv, "=") > 0 {
148+
kv := strings.Split(vv, "=")
149+
key := kv[0]
150+
value := ""
151+
if len(kv) == 2 {
152+
value = kv[1]
153+
}
154+
if _, has := headers[key]; !has {
155+
headers[key] = []string{value}
156+
} else {
157+
headers[key] = append(headers[key], value)
158+
}
159+
} else {
160+
context = v
161+
}
162+
}
163+
return
164+
}
165+
166+
func setHeaders(req *http.Request, headers map[string][]string) {
167+
for k, vv := range headers {
168+
for _, v := range vv {
169+
req.Header.Add(k, v)
170+
}
171+
}
172+
}
173+
137174
// DefaultFuncs returns a list of default functions for binding in the template
138175
func (t *Template) DefaultFuncs() []Function {
176+
139177
return []Function{
140178
{
141179
Name: "source",
@@ -146,10 +184,7 @@ func (t *Template) DefaultFuncs() []Function {
146184
"as the calling template. The context (e.g. variables) of the calling template as a result can be mutated.",
147185
},
148186
Func: func(p string, opt ...interface{}) (string, error) {
149-
var o interface{}
150-
if len(opt) > 0 {
151-
o = opt[0]
152-
}
187+
headers, context := headersAndContext(opt...)
153188
loc := p
154189
if strings.Index(loc, "str://") == -1 {
155190
buff, err := getURL(t.url, p)
@@ -158,7 +193,7 @@ func (t *Template) DefaultFuncs() []Function {
158193
}
159194
loc = buff
160195
}
161-
sourced, err := NewTemplate(loc, t.options)
196+
sourced, err := NewTemplateCustom(loc, t.options, func(req *http.Request) { setHeaders(req, headers) })
162197
if err != nil {
163198
return "", err
164199
}
@@ -167,11 +202,11 @@ func (t *Template) DefaultFuncs() []Function {
167202
sourced.forkFrom(t)
168203
sourced.context = t.context
169204

170-
if o == nil {
171-
o = sourced.context
205+
if context == nil {
206+
context = sourced.context
172207
}
173208
// TODO(chungers) -- let the sourced template define new functions that can be called in the parent.
174-
return sourced.Render(o)
209+
return sourced.Render(context)
175210
},
176211
},
177212
{
@@ -184,10 +219,7 @@ func (t *Template) DefaultFuncs() []Function {
184219
"be visible in the calling template's context.",
185220
},
186221
Func: func(p string, opt ...interface{}) (string, error) {
187-
var o interface{}
188-
if len(opt) > 0 {
189-
o = opt[0]
190-
}
222+
headers, context := headersAndContext(opt...)
191223
loc := p
192224
if strings.Index(loc, "str://") == -1 {
193225
buff, err := getURL(t.url, p)
@@ -196,7 +228,7 @@ func (t *Template) DefaultFuncs() []Function {
196228
}
197229
loc = buff
198230
}
199-
included, err := NewTemplate(loc, t.options)
231+
included, err := NewTemplateCustom(loc, t.options, func(req *http.Request) { setHeaders(req, headers) })
200232
if err != nil {
201233
return "", err
202234
}
@@ -206,11 +238,11 @@ func (t *Template) DefaultFuncs() []Function {
206238
}
207239
included.context = dotCopy
208240

209-
if o == nil {
210-
o = included.context
241+
if context == nil {
242+
context = included.context
211243
}
212244

213-
return included.Render(o)
245+
return included.Render(context)
214246
},
215247
},
216248
{

pkg/template/integration_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@ import (
88
"sync"
99
"testing"
1010

11+
"github.com/docker/infrakit/pkg/log"
12+
"github.com/docker/infrakit/pkg/types"
1113
"github.com/stretchr/testify/require"
1214
)
1315

16+
var logger = log.New("module", "template")
17+
1418
func TestTemplateInclusionFromDifferentSources(t *testing.T) {
1519
prefix := testSetupTemplates(t, testFiles)
1620

@@ -256,3 +260,52 @@ func TestWithFunctions(t *testing.T) {
256260
require.NoError(t, err)
257261
require.Equal(t, "hello=1", view)
258262
}
263+
264+
func TestSourceWithHeaders(t *testing.T) {
265+
266+
h, context := headersAndContext("foo=bar")
267+
logger.Info("result", "context", context, "headers", h)
268+
require.Equal(t, interface{}(nil), context)
269+
require.Equal(t, map[string][]string{"foo": {"bar"}}, h)
270+
271+
h, context = headersAndContext("foo=bar", "bar=baz", 224)
272+
logger.Info("result", "context", context, "headers", h)
273+
require.Equal(t, 224, context)
274+
require.Equal(t, map[string][]string{"foo": {"bar"}, "bar": {"baz"}}, h)
275+
276+
h, context = headersAndContext("foo=bar", "bar=baz")
277+
logger.Info("result", "context", context, "headers", h)
278+
require.Equal(t, nil, context)
279+
require.Equal(t, map[string][]string{"foo": {"bar"}, "bar": {"baz"}}, h)
280+
281+
h, context = headersAndContext("foo")
282+
logger.Info("result", "context", context, "headers", h)
283+
require.Equal(t, "foo", context)
284+
require.Equal(t, map[string][]string{}, h)
285+
286+
h, context = headersAndContext("foo=bar", map[string]string{"hello": "world"})
287+
logger.Info("result", "context", context, "headers", h)
288+
require.Equal(t, map[string]string{"hello": "world"}, context)
289+
require.Equal(t, map[string][]string{"foo": {"bar"}}, h)
290+
291+
// note we don't have to escape -- use the back quote and the string value is valid
292+
r := "{{ include `https://httpbin.org/headers` `A=B` `Foo=Bar` `Foo=Bar` `X=1` 100 }}"
293+
s := `{{ $resp := (source "str://` + r + `" | jsonDecode) }}{{ $resp.headers | jsonEncode}}`
294+
tt, err := NewTemplate("str://"+s, Options{})
295+
require.NoError(t, err)
296+
view, err := tt.Render(nil)
297+
require.NoError(t, err)
298+
299+
any := types.AnyString(view)
300+
headers := map[string]interface{}{}
301+
require.NoError(t, any.Decode(&headers))
302+
require.Equal(t, map[string]interface{}{
303+
"Foo": "Bar,Bar",
304+
"Host": "httpbin.org",
305+
"User-Agent": "Go-http-client/1.1",
306+
"A": "B",
307+
"X": "1",
308+
"Accept-Encoding": "gzip",
309+
"Connection": "close",
310+
}, headers)
311+
}

pkg/template/template.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"fmt"
66
"io"
7+
"net/http"
78
"reflect"
89
"strings"
910
"sync"
@@ -88,6 +89,25 @@ type Void string
8889

8990
const voidValue Void = ""
9091

92+
// NewTemplateCustom fetches the content at the url and allows configuration of the request
93+
// If the string begins with str:// as scheme, then the rest of the string is interpreted as the body of the template.
94+
func NewTemplateCustom(s string, opt Options, custom func(*http.Request)) (*Template, error) {
95+
var buff []byte
96+
contextURL := s
97+
// Special case of specifying the entire template as a string; otherwise treat as url
98+
if strings.Index(s, "str://") == 0 {
99+
buff = []byte(strings.Replace(s, "str://", "", 1))
100+
contextURL = defaultContextURL()
101+
} else {
102+
b, err := Fetch(s, opt, custom)
103+
if err != nil {
104+
return nil, err
105+
}
106+
buff = b
107+
}
108+
return NewTemplateFromBytes(buff, contextURL, opt)
109+
}
110+
91111
// NewTemplate fetches the content at the url and returns a template. If the string begins
92112
// with str:// as scheme, then the rest of the string is interpreted as the body of the template.
93113
func NewTemplate(s string, opt Options) (*Template, error) {
@@ -98,7 +118,7 @@ func NewTemplate(s string, opt Options) (*Template, error) {
98118
buff = []byte(strings.Replace(s, "str://", "", 1))
99119
contextURL = defaultContextURL()
100120
} else {
101-
b, err := Fetch(s, opt)
121+
b, err := Fetch(s, opt, nil)
102122
if err != nil {
103123
return nil, err
104124
}

0 commit comments

Comments
 (0)