Skip to content

Commit f4587e0

Browse files
Func: add new ValidRequest method
1 parent e3f6a8a commit f4587e0

File tree

2 files changed

+30
-28
lines changed

2 files changed

+30
-28
lines changed

signature/signature.go

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"encoding/base64"
88
"fmt"
99
"io/ioutil"
10-
"log"
1110
"math"
1211
"net/http"
1312
"net/url"
@@ -42,19 +41,15 @@ func hMACSHA256(message, key []byte) ([]byte, error) {
4241

4342
// Validator type represents a MessageBird signature validator
4443
type Validator struct {
45-
SigningKey string // Signing Key provided by MessageBird
46-
Period ValidityPeriod // Period in hours for a message to be accepted as real, set to nil to bypass the timestamp validator
47-
Log *log.Logger
48-
LogMesssage *string
44+
SigningKey string // Signing Key provided by MessageBird
45+
Period ValidityPeriod // Period in hours for a message to be accepted as real, set to nil to bypass the timestamp validator
4946
}
5047

5148
// NewValidator returns a signature validator object
52-
func NewValidator(signingKey string, period ValidityPeriod, log *log.Logger, message *string) *Validator {
49+
func NewValidator(signingKey string, period ValidityPeriod) *Validator {
5350
return &Validator{
54-
SigningKey: signingKey,
55-
Period: period,
56-
Log: log,
57-
LogMesssage: message,
51+
SigningKey: signingKey,
52+
Period: period,
5853
}
5954
}
6055

@@ -110,29 +105,36 @@ func (v *Validator) ValidSignature(ts, rqp string, b []byte, rs string) bool {
110105
}
111106

112107
func (v *Validator) Error(w http.ResponseWriter, r *http.Request) {
113-
if v.Log != nil {
114-
v.Log.Printf("%s, sending host: %s", *v.LogMesssage, r.Host)
108+
109+
}
110+
111+
// ValidRequest is a method that takes care of the signature validation of
112+
// incoming requests
113+
// To use just wrap your handler with it:
114+
// signature.Validate(request)
115+
func (v *Validator) ValidRequest(r *http.Request) (bool, error) {
116+
ts := r.Header.Get(tsHeader)
117+
rs := r.Header.Get(sHeader)
118+
if ts == "" || rs == "" {
119+
return false, fmt.Errorf("Unknown host: %s", r.Host)
115120
}
116-
http.Error(w, "Request not allowed", http.StatusUnauthorized)
117-
return
121+
b, _ := ioutil.ReadAll(r.Body)
122+
if v.ValidTimestamp(ts) == false || v.ValidSignature(ts, r.URL.RawQuery, b, rs) == false {
123+
return false, fmt.Errorf("Unknown host: %s", r.Host)
124+
}
125+
r.Body = ioutil.NopCloser(bytes.NewBuffer(b))
126+
return true, nil
118127
}
119128

120129
// Validate is a handler wrapper that takes care of the signature validation of
121130
// incoming requests and rejects them if invalid or pass them on to your handler
122131
// otherwise.
123-
// To use just wrappe your handler with it:
132+
// To use just wrap your handler with it:
124133
// http.Handle("/path", signature.Validate(handleThing))
125134
func (v *Validator) Validate(h http.Handler) http.Handler {
126135
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
127-
ts := r.Header.Get(tsHeader)
128-
rs := r.Header.Get(sHeader)
129-
if ts == "" || rs == "" {
130-
v.Error(w, r)
131-
return
132-
}
133-
b, _ := ioutil.ReadAll(r.Body)
134-
if v.ValidTimestamp(ts) == false || v.ValidSignature(ts, r.URL.RawQuery, b, rs) == false {
135-
v.Error(w, r)
136+
if res, _ := v.ValidRequest(r); res == false {
137+
http.Error(w, "", http.StatusUnauthorized)
136138
return
137139
}
138140
h.ServeHTTP(w, r)

signature/signature_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func TestCalculateSignature(t *testing.T) {
6868
},
6969
}
7070
for i, tt := range cases {
71-
v := NewValidator(tt.sKey, nil, nil, nil)
71+
v := NewValidator(tt.sKey, nil)
7272
s, err := v.CalculateSignature(tt.ts, tt.qp, []byte(tt.b))
7373
if err != nil {
7474
t.Errorf("Error calculating signature: %s, expected: %s", s, tt.es)
@@ -122,7 +122,7 @@ func TestValidTimestamp(t *testing.T) {
122122
}
123123

124124
for i, tt := range cases {
125-
v := NewValidator(testKey, tt.p, nil, nil)
125+
v := NewValidator(testKey, tt.p)
126126
r := v.ValidTimestamp(tt.ts)
127127
if r != tt.e {
128128
t.Errorf("Unexpected error validating ts: %s, test case: %d", tt.ts, i)
@@ -162,7 +162,7 @@ func TestValidSignature(t *testing.T) {
162162
}
163163

164164
for i, tt := range cases {
165-
v := NewValidator(testKey, nil, nil, nil)
165+
v := NewValidator(testKey, nil)
166166
r := v.ValidSignature(tt.ts, tt.qp, []byte(tt.b), tt.s)
167167
if r != tt.e {
168168
t.Errorf("Unexpected error validating signature: %s, test case: %d", tt.s, i)
@@ -229,7 +229,7 @@ func TestValidate(t *testing.T) {
229229
}
230230

231231
for i, tt := range cases {
232-
v := NewValidator(tt.k, nil, nil, nil)
232+
v := NewValidator(tt.k, nil)
233233
ts := httptest.NewServer(v.Validate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
234234
w.WriteHeader(http.StatusOK)
235235
})))

0 commit comments

Comments
 (0)