Skip to content

Commit c1551af

Browse files
author
khanh.nguyen
committed
Support new webhook signature method
1 parent cb04299 commit c1551af

File tree

5 files changed

+270
-278
lines changed

5 files changed

+270
-278
lines changed

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ module github.com/messagebird/go-rest-api/v7
22

33
go 1.14
44

5-
require github.com/stretchr/testify v1.7.0
5+
require (
6+
github.com/dgrijalva/jwt-go v3.2.0+incompatible
7+
github.com/stretchr/testify v1.7.0
8+
)

go.sum

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
22
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3+
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
4+
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
35
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
46
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5-
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
67
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
78
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
89
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

signature/claims.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package signature
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
"time"
7+
8+
"github.com/dgrijalva/jwt-go"
9+
)
10+
11+
// maxSkew is the maximum time skew that we accept. Sometimes the Internet is *so* fast
12+
// that messages are received before they are sent, or the clocks of two servers are not
13+
// in-sync, whichever cause seems more likely to you.
14+
const maxSkew = 1 * time.Second
15+
16+
// Claims replaces jwt.StandardClaims as it checks all aspects of the the JWT token that
17+
// have been specified by the MessageBird RFC.
18+
type Claims struct {
19+
// The following 3 fields are added to Claims before JWT is parsed so that the
20+
// immediately following call to Valid() by jwt-go has *all* necessary information to
21+
// determine whether JWT is valid. These fields should not be overwritten by JSON
22+
// unmarshal.
23+
receivedTime time.Time `json:"-"`
24+
correctPayloadHash string `json:"-"`
25+
correctURLHash string `json:"-"`
26+
27+
Issuer string `json:"iss"`
28+
IssuedAt int64 `json:"iat"`
29+
ExpirationTime int64 `json:"exp"`
30+
JWTID string `json:"jti"`
31+
URLHash string `json:"url_hash"`
32+
PayloadHash string `json:"payload_hash,omitempty"`
33+
}
34+
35+
// Valid is called by jwt-go after the Claims struct has been filled. If an error is
36+
// returned, it means that the JWT should not be trusted.
37+
func (c Claims) Valid() error {
38+
var errs []string
39+
40+
if c.Issuer != "MessageBird" {
41+
errs = append(errs, "wrong iss")
42+
}
43+
44+
if iat := time.Unix(c.IssuedAt, int64(c.receivedTime.Nanosecond())).Add(-maxSkew); c.receivedTime.Before(iat) {
45+
errs = append(errs, "iat is in the future")
46+
}
47+
48+
if exp := time.Unix(c.ExpirationTime, int64(c.receivedTime.Nanosecond())).Add(maxSkew); c.receivedTime.After(exp) {
49+
errs = append(errs, "exp is in the past")
50+
}
51+
52+
if c.JWTID == "" {
53+
errs = append(errs, "jti is empty or missing")
54+
}
55+
56+
if c.correctURLHash != c.URLHash {
57+
errs = append(errs, "url_hash is invalid")
58+
}
59+
60+
switch {
61+
case c.correctPayloadHash == "" && c.PayloadHash != "":
62+
errs = append(errs, "payload_hash was set; expected no payload value")
63+
case c.correctPayloadHash != "" && c.correctPayloadHash != c.PayloadHash:
64+
errs = append(errs, "payload_hash is invalid")
65+
}
66+
67+
if len(errs) == 0 {
68+
return nil
69+
}
70+
return fmt.Errorf("%s", strings.Join(errs, "; "))
71+
}
72+
73+
// Claims satisfies jwt.Claims.
74+
var _ jwt.Claims = Claims{}

signature/signature.go

Lines changed: 58 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,124 +1,98 @@
11
/*
22
Package signature implements signature verification for MessageBird webhooks.
33
4-
To use define a new validator using your MessageBird Signing key. You can use the
5-
ValidRequest method, just pass the request as a parameter:
4+
To use define a new validator using your MessageBird Signing key. You can use the
5+
ValidRequest method, just pass the request and base url as parameters:
66
7-
validator := signature.NewValidator("your signing key")
8-
if err := validator.ValidRequest(r); err != nil {
7+
validator := signature.NewValidator([]byte("your signing key"))
8+
baseUrl := "https://messagebird.io"
9+
if err := validator.ValidRequest(r, baseUrl); err != nil {
910
// handle error
1011
}
1112
1213
Or use the handler as a middleware for your server:
1314
14-
http.Handle("/path", validator.Validate(YourHandler))
15+
http.Handle("/path", validator.Validate(YourHandler, baseUrl))
1516
1617
It will reject the requests that contain invalid signatures.
17-
The validator uses a 5ms seconds window to accept requests as valid, to change
18-
this value, set the ValidityWindow to the disired duration.
19-
Take into account that the validity window works around the current time:
20-
[now - ValidityWindow/2, now + ValidityWindow/2]
2118
*/
2219
package signature
2320

2421
import (
2522
"bytes"
26-
"crypto/hmac"
2723
"crypto/sha256"
28-
"encoding/base64"
24+
"encoding/hex"
2925
"fmt"
3026
"io/ioutil"
3127
"net/http"
3228
"net/url"
33-
"strconv"
3429
"time"
35-
)
3630

37-
const (
38-
tsHeader = "MessageBird-Request-Timestamp"
39-
sHeader = "MessageBird-Signature"
31+
"github.com/dgrijalva/jwt-go"
4032
)
4133

42-
// ValidityWindow defines the time window in which to validate a request.
43-
var ValidityWindow = 5 * time.Second
34+
const signatureHeader = "MessageBird-Signature-JWT"
4435

45-
// StringToTime converts from Unicode Epoch encoded timestamps to the time.Time type.
46-
func stringToTime(s string) (time.Time, error) {
47-
sec, err := strconv.ParseInt(s, 10, 64)
48-
if err != nil {
49-
return time.Time{}, err
50-
}
51-
return time.Unix(sec, 0), nil
36+
// TimeFunc provides the current time same as time.Now but can be overridden for testing.
37+
var TimeFunc = time.Now
38+
39+
// allowedMethods lists the signing methods that we accept. We only allow symmetric-key
40+
// algorithms as our customer signing keys are currently all simple byte strings. HMAC is
41+
// also the only symkey signature method that is required by the RFC7518 Section 3.1 and
42+
// thus should be supported by all JWT implementations.
43+
var allowedMethods = []string{
44+
jwt.SigningMethodHS256.Name,
45+
jwt.SigningMethodHS384.Name,
46+
jwt.SigningMethodHS512.Name,
5247
}
5348

5449
// Validator type represents a MessageBird signature validator.
5550
type Validator struct {
56-
SigningKey string // Signing Key provided by MessageBird.
51+
SigningKey []byte // Signing Key provided by MessageBird.
5752
}
5853

5954
// NewValidator returns a signature validator object.
60-
func NewValidator(signingKey string) *Validator {
55+
func NewValidator(signingKey []byte) *Validator {
6156
return &Validator{
6257
SigningKey: signingKey,
6358
}
6459
}
6560

66-
// validTimestamp validates if the MessageBird-Request-Timestamp is a valid
67-
// date and if the request is older than the validator Period.
68-
func (v *Validator) validTimestamp(ts string) bool {
69-
t, err := stringToTime(ts)
70-
if err != nil {
71-
return false
72-
}
73-
diff := time.Now().Add(ValidityWindow / 2).Sub(t)
74-
return diff < ValidityWindow && diff > 0
75-
}
61+
// ValidSignature is a method that takes care of the signature validation of
62+
// incoming requests.
63+
func (v *Validator) ValidSignature(signature, url string, payload []byte) error {
64+
parser := jwt.Parser{ValidMethods: allowedMethods}
65+
keyFn := func(*jwt.Token) (interface{}, error) { return v.SigningKey, nil }
7666

77-
// calculateSignature calculates the MessageBird-Signature using HMAC_SHA_256
78-
// encoding and the timestamp, query params and body from the request:
79-
// signature = HMAC_SHA_256(
80-
// TIMESTAMP + \n + QUERY_PARAMS + \n + SHA_256_SUM(BODY),
81-
// signing_key)
82-
func (v *Validator) calculateSignature(ts, qp string, b []byte) ([]byte, error) {
83-
var m bytes.Buffer
84-
bh := sha256.Sum256(b)
85-
fmt.Fprintf(&m, "%s\n%s\n%s", ts, qp, bh[:])
86-
mac := hmac.New(sha256.New, []byte(v.SigningKey))
87-
if _, err := mac.Write(m.Bytes()); err != nil {
88-
return nil, err
67+
claims := Claims{
68+
receivedTime: TimeFunc(),
69+
correctURLHash: sha256Hash([]byte(url)),
8970
}
90-
return mac.Sum(nil), nil
91-
}
92-
93-
// validSignature takes the timestamp, query params and body from the request,
94-
// calculates the expected signature and compares it to the one sent by MessageBird.
95-
func (v *Validator) validSignature(ts, rqp string, b []byte, rs string) bool {
96-
uqp, err := url.Parse("?" + rqp)
97-
if err != nil {
98-
return false
71+
if payload != nil && len(payload) != 0 {
72+
claims.correctPayloadHash = sha256Hash(payload)
9973
}
100-
es, err := v.calculateSignature(ts, uqp.Query().Encode(), b)
101-
if err != nil {
102-
return false
103-
}
104-
drs, err := base64.StdEncoding.DecodeString(rs)
105-
if err != nil {
106-
return false
74+
75+
if _, err := parser.ParseWithClaims(signature, &claims, keyFn); err != nil {
76+
return fmt.Errorf("invalid jwt: %w", err)
10777
}
108-
return hmac.Equal(drs, es)
78+
79+
return nil
10980
}
11081

11182
// ValidRequest is a method that takes care of the signature validation of
11283
// incoming requests.
113-
func (v *Validator) ValidRequest(r *http.Request) error {
114-
ts := r.Header.Get(tsHeader)
115-
rs := r.Header.Get(sHeader)
116-
if ts == "" || rs == "" {
117-
return fmt.Errorf("Unknown host: %s", r.Host)
84+
func (v *Validator) ValidRequest(r *http.Request, baseUrl string) error {
85+
base, err := url.Parse(baseUrl)
86+
if err != nil {
87+
return fmt.Errorf("error parsing base url: %v", err)
88+
}
89+
signature := r.Header.Get(signatureHeader)
90+
if signature == "" {
91+
return fmt.Errorf("signature not found")
11892
}
11993
b, _ := ioutil.ReadAll(r.Body)
120-
if !v.validTimestamp(ts) || !v.validSignature(ts, r.URL.RawQuery, b, rs) {
121-
return fmt.Errorf("Unknown host: %s", r.Host)
94+
if err := v.ValidSignature(signature, base.ResolveReference(r.URL).String(), b); err != nil {
95+
return fmt.Errorf("invalid signature: %s", err.Error())
12296
}
12397
r.Body = ioutil.NopCloser(bytes.NewBuffer(b))
12498
return nil
@@ -127,12 +101,21 @@ func (v *Validator) ValidRequest(r *http.Request) error {
127101
// Validate is a handler wrapper that takes care of the signature validation of
128102
// incoming requests and rejects them if invalid or pass them on to your handler
129103
// otherwise.
130-
func (v *Validator) Validate(h http.Handler) http.Handler {
104+
func (v *Validator) Validate(h http.Handler, baseUrl string) http.Handler {
131105
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132-
if err := v.ValidRequest(r); err != nil {
106+
if err := v.ValidRequest(r, baseUrl); err != nil {
133107
http.Error(w, "", http.StatusUnauthorized)
134108
return
135109
}
136110
h.ServeHTTP(w, r)
137111
})
138112
}
113+
114+
func sha256Hash(data []byte) string {
115+
if data == nil {
116+
return ""
117+
}
118+
119+
h := sha256.Sum256(data)
120+
return hex.EncodeToString(h[:])
121+
}

0 commit comments

Comments
 (0)