Skip to content

Commit c784a5d

Browse files
author
marcel corso
committed
Merge branch 'master' into add-conversations
2 parents 430d19f + 12d131c commit c784a5d

File tree

10 files changed

+512
-34
lines changed

10 files changed

+512
-34
lines changed

client.go

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525

2626
const (
2727
// ClientVersion is used in User-Agent request header to provide server with API level.
28-
ClientVersion = "5.0.0"
28+
ClientVersion = "5.1.1"
2929

3030
// Endpoint points you to MessageBird REST API.
3131
Endpoint = "https://rest.messagebird.com"
@@ -47,6 +47,14 @@ type Client struct {
4747
DebugLog *log.Logger // Optional logger for debugging purposes
4848
}
4949

50+
type contentType string
51+
52+
const (
53+
contentTypeEmpty contentType = ""
54+
contentTypeJSON contentType = "application/json"
55+
contentTypeFormURLEncoded contentType = "application/x-www-form-urlencoded"
56+
)
57+
5058
// New creates a new MessageBird client object.
5159
func New(accessKey string) *Client {
5260
return &Client{
@@ -67,27 +75,26 @@ func (c *Client) Request(v interface{}, method, path string, data interface{}) e
6775
return err
6876
}
6977

70-
var jsonEncoded []byte
71-
if data != nil {
72-
jsonEncoded, err = json.Marshal(data)
73-
if err != nil {
74-
return err
75-
}
78+
body, contentType, err := prepareRequestBody(data)
79+
if err != nil {
80+
return err
7681
}
7782

78-
request, err := http.NewRequest(method, uri.String(), bytes.NewBuffer(jsonEncoded))
83+
request, err := http.NewRequest(method, uri.String(), bytes.NewBuffer(body))
7984
if err != nil {
8085
return err
8186
}
8287

83-
request.Header.Set("Content-Type", "application/json")
8488
request.Header.Set("Accept", "application/json")
8589
request.Header.Set("Authorization", "AccessKey "+c.AccessKey)
8690
request.Header.Set("User-Agent", "MessageBird/ApiClient/"+ClientVersion+" Go/"+runtime.Version())
91+
if contentType != contentTypeEmpty {
92+
request.Header.Set("Content-Type", string(contentType))
93+
}
8794

8895
if c.DebugLog != nil {
8996
if data != nil {
90-
c.DebugLog.Printf("HTTP REQUEST: %s %s %s", method, uri.String(), jsonEncoded)
97+
c.DebugLog.Printf("HTTP REQUEST: %s %s %s", method, uri.String(), body)
9198
} else {
9299
c.DebugLog.Printf("HTTP REQUEST: %s %s", method, uri.String())
93100
}
@@ -136,3 +143,22 @@ func (c *Client) Request(v interface{}, method, path string, data interface{}) e
136143
return errorResponse
137144
}
138145
}
146+
147+
// prepareRequestBody takes untyped data and attempts constructing a meaningful
148+
// request body from it. It also returns the appropriate Content-Type.
149+
func prepareRequestBody(data interface{}) ([]byte, contentType, error) {
150+
switch data := data.(type) {
151+
case nil:
152+
// Nil bodies are accepted by `net/http`, so this is not an error.
153+
return nil, contentTypeEmpty, nil
154+
case string:
155+
return []byte(data), contentTypeFormURLEncoded, nil
156+
default:
157+
b, err := json.Marshal(data)
158+
if err != nil {
159+
return nil, contentType(""), err
160+
}
161+
162+
return b, contentTypeJSON, nil
163+
}
164+
}

group/group.go

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,9 @@ func AddContacts(c *messagebird.Client, groupID string, contactIDs []string) err
161161
return err
162162
}
163163

164-
query := addContactsQuery(contactIDs)
165-
formattedPath := fmt.Sprintf("%s/%s/%s?%s", path, groupID, contactPath, query)
164+
data := addContactsData(contactIDs)
166165

167-
return c.Request(nil, http.MethodGet, formattedPath, nil)
166+
return c.Request(nil, http.MethodPut, path+"/"+groupID+"/"+contactPath, data)
168167
}
169168

170169
func validateAddContacts(contactIDs []string) error {
@@ -182,19 +181,12 @@ func validateAddContacts(contactIDs []string) error {
182181
return nil
183182
}
184183

185-
// addContactsQuery gets a query string to add contacts to a group. We're using
186-
// the alternative "/foo?_method=PUT&key=value" format to send the contact IDs
187-
// as GET params. Sending these in the request body would require a painful
188-
// workaround, as client.Request sends request bodies as JSON by default. See
189-
// also: https://developers.messagebird.com/docs/alternatives.
190-
//
191-
// It should also be noted that we're intentionally not using url.Values for
192-
// building the query string: the API expects `ids[]=foo&ids[]=bar` format,
193-
// while url.Values encodes to `ids=foo&ids=bar`.
194-
func addContactsQuery(contactIDs []string) string {
195-
// Slice's length is one bigger than len(IDs) for the _method param.
196-
params := make([]string, 0, len(contactIDs)+1)
197-
params = append(params, "_method="+http.MethodPut)
184+
// addContactsData gets the data string for adding a contact to a group. We're
185+
// intentionally not using url.Values for building the string: the API expects
186+
// `ids[]=foo&ids[]=bar` format, while url.Values encodes to `ids=foo&ids=bar`.
187+
func addContactsData(contactIDs []string) string {
188+
cap := len(contactIDs)
189+
params := make([]string, 0, cap)
198190

199191
for _, contactID := range contactIDs {
200192
params = append(params, "ids[]="+contactID)

group/group_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ func TestUpdate(t *testing.T) {
159159

160160
mbtest.AssertEndpointCalled(t, http.MethodPatch, "/groups/group-id")
161161
mbtest.AssertTestdata(t, "groupRequestUpdateObject.json", mbtest.Request.Body)
162+
163+
if mbtest.Request.ContentType != "application/json" {
164+
t.Fatalf("got %s, expected application/json", mbtest.Request.ContentType)
165+
}
162166
}
163167

164168
func TestAddContacts(t *testing.T) {
@@ -169,10 +173,11 @@ func TestAddContacts(t *testing.T) {
169173
t.Fatalf("unexpected error removing Contacts from Group: %s", err)
170174
}
171175

172-
mbtest.AssertEndpointCalled(t, http.MethodGet, "/groups/group-id/contacts")
176+
mbtest.AssertEndpointCalled(t, http.MethodPut, "/groups/group-id/contacts")
177+
mbtest.AssertTestdata(t, "groupRequestAddContactsObject.txt", mbtest.Request.Body)
173178

174-
if mbtest.Request.URL.RawQuery != "_method=PUT&ids[]=first-contact-id&ids[]=second-contact-id" {
175-
t.Fatalf("got %s, expected _method=PUT&ids[]=first-contact-id&ids[]=second-contact-id", mbtest.Request.URL.RawQuery)
179+
if mbtest.Request.ContentType != "application/x-www-form-urlencoded" {
180+
t.Fatalf("got %s, expected application/x-www-form-urlencoded", mbtest.Request.ContentType)
176181
}
177182
}
178183

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ids[]=first-contact-id&ids[]=second-contact-id

internal/mbtest/test_server.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import (
1010
)
1111

1212
type request struct {
13-
Body []byte
14-
Method string
15-
URL *url.URL
13+
Body []byte
14+
ContentType, Method string
15+
URL *url.URL
1616
}
1717

1818
// Request contains the lastly received http.Request by the fake server.
@@ -35,8 +35,9 @@ func EnableServer(m *testing.M) {
3535
func initAndStartServer() {
3636
server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3737
Request = request{
38-
Method: r.Method,
39-
URL: r.URL,
38+
ContentType: r.Header.Get("Content-Type"),
39+
Method: r.Method,
40+
URL: r.URL,
4041
}
4142

4243
var err error

signature/signature.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
Package signature implements signature verification for MessageBird webhooks.
3+
4+
To use define a new validator using your MessageBird Signing key. You can use the
5+
ValidRequest method, just pass the request as a parameter:
6+
7+
validator := signature.NewValidator("your signing key")
8+
if err := validator.ValidRequest(r); err != nil {
9+
// handle error
10+
}
11+
12+
Or use the handler as a middleware for your server:
13+
14+
http.Handle("/path", validator.Validate(YourHandler))
15+
16+
It will reject the requests that contain invalid signatures.
17+
The validator uses a 5ms seconds window to accept requests as valid, to change
18+
this value, set the ValidityWindow to the disired duration.
19+
Take into account that the validity window works around the current time:
20+
[now - ValidityWindow/2, now + ValidityWindow/2]
21+
*/
22+
package signature
23+
24+
import (
25+
"bytes"
26+
"crypto/hmac"
27+
"crypto/sha256"
28+
"encoding/base64"
29+
"fmt"
30+
"io/ioutil"
31+
"net/http"
32+
"net/url"
33+
"strconv"
34+
"time"
35+
)
36+
37+
const (
38+
tsHeader = "MessageBird-Request-Timestamp"
39+
sHeader = "MessageBird-Signature"
40+
)
41+
42+
// ValidityWindow defines the time window in which to validate a request.
43+
var ValidityWindow = 5 * time.Second
44+
45+
// StringToTime converts from Unicode Epoch encoded timestamps to the time.Time type.
46+
func stringToTime(s string) (time.Time, error) {
47+
sec, err := strconv.ParseInt(s, 10, 64)
48+
if err != nil {
49+
return time.Time{}, err
50+
}
51+
return time.Unix(sec, 0), nil
52+
}
53+
54+
// Validator type represents a MessageBird signature validator.
55+
type Validator struct {
56+
SigningKey string // Signing Key provided by MessageBird.
57+
}
58+
59+
// NewValidator returns a signature validator object.
60+
func NewValidator(signingKey string) *Validator {
61+
return &Validator{
62+
SigningKey: signingKey,
63+
}
64+
}
65+
66+
// validTimestamp validates if the MessageBird-Request-Timestamp is a valid
67+
// date and if the request is older than the validator Period.
68+
func (v *Validator) validTimestamp(ts string) bool {
69+
t, err := stringToTime(ts)
70+
if err != nil {
71+
return false
72+
}
73+
diff := time.Now().Add(ValidityWindow / 2).Sub(t)
74+
return diff < ValidityWindow && diff > 0
75+
}
76+
77+
// calculateSignature calculates the MessageBird-Signature using HMAC_SHA_256
78+
// encoding and the timestamp, query params and body from the request:
79+
// signature = HMAC_SHA_256(
80+
// TIMESTAMP + \n + QUERY_PARAMS + \n + SHA_256_SUM(BODY),
81+
// signing_key)
82+
func (v *Validator) calculateSignature(ts, qp string, b []byte) ([]byte, error) {
83+
var m bytes.Buffer
84+
bh := sha256.Sum256(b)
85+
fmt.Fprintf(&m, "%s\n%s\n%s", ts, qp, bh[:])
86+
mac := hmac.New(sha256.New, []byte(v.SigningKey))
87+
if _, err := mac.Write(m.Bytes()); err != nil {
88+
return nil, err
89+
}
90+
return mac.Sum(nil), nil
91+
}
92+
93+
// validSignature takes the timestamp, query params and body from the request,
94+
// calculates the expected signature and compares it to the one sent by MessageBird.
95+
func (v *Validator) validSignature(ts, rqp string, b []byte, rs string) bool {
96+
uqp, err := url.Parse("?" + rqp)
97+
if err != nil {
98+
return false
99+
}
100+
es, err := v.calculateSignature(ts, uqp.Query().Encode(), b)
101+
if err != nil {
102+
return false
103+
}
104+
drs, err := base64.StdEncoding.DecodeString(rs)
105+
if err != nil {
106+
return false
107+
}
108+
return hmac.Equal(drs, es)
109+
}
110+
111+
// ValidRequest is a method that takes care of the signature validation of
112+
// incoming requests.
113+
func (v *Validator) ValidRequest(r *http.Request) error {
114+
ts := r.Header.Get(tsHeader)
115+
rs := r.Header.Get(sHeader)
116+
if ts == "" || rs == "" {
117+
return fmt.Errorf("Unknown host: %s", r.Host)
118+
}
119+
b, _ := ioutil.ReadAll(r.Body)
120+
if v.validTimestamp(ts) == false || v.validSignature(ts, r.URL.RawQuery, b, rs) == false {
121+
return fmt.Errorf("Unknown host: %s", r.Host)
122+
}
123+
r.Body = ioutil.NopCloser(bytes.NewBuffer(b))
124+
return nil
125+
}
126+
127+
// Validate is a handler wrapper that takes care of the signature validation of
128+
// incoming requests and rejects them if invalid or pass them on to your handler
129+
// otherwise.
130+
func (v *Validator) Validate(h http.Handler) http.Handler {
131+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132+
if err := v.ValidRequest(r); err != nil {
133+
http.Error(w, "", http.StatusUnauthorized)
134+
return
135+
}
136+
h.ServeHTTP(w, r)
137+
})
138+
}

0 commit comments

Comments
 (0)