Skip to content

Commit 6acdc9b

Browse files
Juanita De La CuestaJuanita De La Cuesta
authored andcommitted
Func: Rearreange qp before calculating
1 parent b38424b commit 6acdc9b

File tree

2 files changed

+255
-18
lines changed

2 files changed

+255
-18
lines changed

signature/signature.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@ import (
88
"fmt"
99
"io/ioutil"
1010
"log"
11+
"math"
1112
"net/http"
13+
"net/url"
1214
"strconv"
1315
"time"
1416
)
1517

18+
const tsHeader = "MessageBird-Request-Timestamp"
19+
const sHeader = "MessageBird-Signature"
20+
1621
// ValidityPeriod is the time in hours after which a request is descarded
1722
type ValidityPeriod *float64
1823

@@ -44,10 +49,10 @@ type Validator struct {
4449
}
4550

4651
// NewValidator returns a signature validator object
47-
func NewValidator(signingKey string, period float64, log *log.Logger, message *string) *Validator {
52+
func NewValidator(signingKey string, period ValidityPeriod, log *log.Logger, message *string) *Validator {
4853
return &Validator{
4954
SigningKey: signingKey,
50-
Period: &period,
55+
Period: period,
5156
Log: log,
5257
LogMesssage: message,
5358
}
@@ -56,11 +61,14 @@ func NewValidator(signingKey string, period float64, log *log.Logger, message *s
5661
// ValidTimestamp validates if the MessageBird-Request-Timestamp is a valid
5762
// date and if the request is older than the validator Period.
5863
func (v *Validator) ValidTimestamp(ts string) bool {
64+
t, err := StringToTime(ts)
65+
if err != nil {
66+
return false
67+
}
5968
if v.Period != nil {
6069
now := time.Now()
61-
t, err := StringToTime(ts)
6270
diff := now.Sub(t)
63-
if err != nil || diff.Hours() > *v.Period {
71+
if math.Abs(diff.Hours()) > *v.Period {
6472
return false
6573
}
6674
}
@@ -85,15 +93,16 @@ func (v *Validator) CalculateSignature(ts, qp string, b []byte) ([]byte, error)
8593

8694
// ValidSignature takes the timestamp, query params and body from the request,
8795
// calculates the expected signature and compares it to the one sent by MessageBird.
88-
func (v *Validator) ValidSignature(ts, qp, rs string, b []byte) bool {
89-
es, _ := v.CalculateSignature(ts, qp, b)
96+
func (v *Validator) ValidSignature(ts, rqp string, b []byte, rs string) bool {
97+
uqp, _ := url.Parse("?" + rqp)
98+
es, _ := v.CalculateSignature(ts, uqp.Query().Encode(), b)
9099
drs, _ := base64.StdEncoding.DecodeString(rs)
91100
return hmac.Equal(drs, es)
92101
}
93102

94103
func (v *Validator) Error(w http.ResponseWriter, r *http.Request) {
95104
if v.Log != nil {
96-
v.Log.Println(v.LogMesssage, r.Host)
105+
v.Log.Printf("%s, sending host: %s", *v.LogMesssage, r.Host)
97106
}
98107
http.Error(w, "Request not allowed", http.StatusUnauthorized)
99108
return
@@ -103,17 +112,17 @@ func (v *Validator) Error(w http.ResponseWriter, r *http.Request) {
103112
// incoming requests and rejects them if invalid or pass them on to your handler
104113
// otherwise.
105114
// To use just wrappe your handler with it:
106-
// http.Handle("/path", signature.Validate(handleThing))
115+
// http.Handle("/path", signature.Validate(handleThing))
107116
func (v *Validator) Validate(h http.Handler) http.Handler {
108117
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109-
ts := r.Header.Get("MessageBird-Request-Timestamp")
110-
rs := r.Header.Get("MessageBird-Request-Signature")
118+
ts := r.Header.Get(tsHeader)
119+
rs := r.Header.Get(sHeader)
111120
if ts == "" || rs == "" {
112121
v.Error(w, r)
113122
return
114123
}
115124
b, _ := ioutil.ReadAll(r.Body)
116-
if v.ValidTimestamp(ts) == false || v.ValidSignature(ts, r.URL.RawQuery, rs, b) == false {
125+
if v.ValidTimestamp(ts) == false || v.ValidSignature(ts, r.URL.RawQuery, b, rs) == false {
117126
v.Error(w, r)
118127
return
119128
}

signature/signature_test.go

Lines changed: 235 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,250 @@ package signature
33
import (
44
"bytes"
55
"encoding/base64"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"strings"
610
"testing"
11+
"time"
712
)
813

914
const testTs = "1544544948"
1015
const testQp = "abc=foo&def=bar"
1116
const testBody = `{"a key":"some value"}`
1217
const testSignature = "orb0adPhRCYND1WCAvPBr+qjm4STGtyvNDIDNBZ4Ir4="
18+
const testKey = "other-secret"
1319

1420
func TestCalculateSignature(t *testing.T) {
15-
v := NewValidator("other-secret", 2, nil, nil)
16-
s, err := v.CalculateSignature(testTs, testQp, []byte(testBody))
17-
if err != nil {
18-
t.Errorf("Error calculating signature: %s, expected: orb0adPhRCYND1WCAvPBr+qjm4STGtyvNDIDNBZ4Ir4=", s)
21+
var cases = []struct {
22+
sKey string
23+
ts string
24+
qp string
25+
b string
26+
es string
27+
e bool
28+
}{
29+
{
30+
sKey: testKey,
31+
ts: testTs,
32+
qp: testQp,
33+
b: testBody,
34+
es: testSignature,
35+
e: true,
36+
},
37+
{
38+
sKey: testKey,
39+
ts: testTs,
40+
qp: testQp,
41+
b: testBody,
42+
es: "LISw4Je7n0/MkYDgVSzTJm8dW6BkytKTXMZZk1IElMs=",
43+
e: false,
44+
},
45+
{
46+
sKey: "secret",
47+
ts: testTs,
48+
qp: "",
49+
b: "",
50+
es: "LISw4Je7n0/MkYDgVSzTJm8dW6BkytKTXMZZk1IElMs=",
51+
e: true,
52+
},
53+
{
54+
sKey: "secret",
55+
ts: testTs,
56+
qp: "",
57+
b: testBody,
58+
es: "p2e20OtAg39DEmz1ORHpjQ556U4o1ZaH4NWbM9Q8Qjk=",
59+
e: true,
60+
},
61+
{
62+
sKey: "secret",
63+
ts: testTs,
64+
qp: testQp,
65+
b: "",
66+
es: "Tfn+nRUBsn6lQgf6IpxBMS1j9lm7XsGjt5xh47M3jCk=",
67+
e: true,
68+
},
1969
}
20-
drs, _ := base64.StdEncoding.DecodeString(testSignature)
21-
if bytes.Compare(s, drs) != 0 {
22-
t.Errorf("Unexpected signature: %s, expected: orb0adPhRCYND1WCAvPBr+qjm4STGtyvNDIDNBZ4Ir4=", s)
70+
for i, tt := range cases {
71+
v := NewValidator(tt.sKey, nil, nil, nil)
72+
s, err := v.CalculateSignature(tt.ts, tt.qp, []byte(tt.b))
73+
if err != nil {
74+
t.Errorf("Error calculating signature: %s, expected: %s", s, tt.es)
75+
}
76+
drs, _ := base64.StdEncoding.DecodeString(tt.es)
77+
e := bool(bytes.Compare(s, drs) == 0)
78+
if e != tt.e {
79+
t.Errorf("Unexpected signature: %s, test case: %d", s, i)
80+
}
2381
}
2482
}
83+
func TestValidTimestamp(t *testing.T) {
84+
var p float64 = 2
85+
now := time.Now()
86+
nowts := fmt.Sprintf("%d", now.Unix())
87+
var cases = []struct {
88+
ts string
89+
p ValidityPeriod
90+
e bool
91+
}{
92+
{
93+
ts: nowts,
94+
p: nil,
95+
e: true,
96+
},
97+
{
98+
ts: "",
99+
p: nil,
100+
e: false,
101+
},
102+
{
103+
ts: "wrongTs",
104+
p: nil,
105+
e: false,
106+
},
107+
{
108+
ts: nowts,
109+
p: &p,
110+
e: true,
111+
},
112+
{
113+
ts: fmt.Sprintf("%d", now.AddDate(0, 0, 1).Unix()),
114+
p: &p,
115+
e: false,
116+
},
117+
{
118+
ts: fmt.Sprintf("%d", now.AddDate(0, 0, -1).Unix()),
119+
p: &p,
120+
e: false,
121+
},
122+
}
123+
124+
for i, tt := range cases {
125+
v := NewValidator(testKey, tt.p, nil, nil)
126+
r := v.ValidTimestamp(tt.ts)
127+
if r != tt.e {
128+
t.Errorf("Unexpected error validating ts: %s, test case: %d", tt.ts, i)
129+
}
130+
}
131+
}
132+
133+
func TestValidSignature(t *testing.T) {
134+
var cases = []struct {
135+
ts string
136+
qp string
137+
b string
138+
s string
139+
e bool
140+
}{
141+
{
142+
ts: testTs,
143+
qp: testQp,
144+
b: testBody,
145+
s: testSignature,
146+
e: true,
147+
},
148+
{
149+
ts: testTs,
150+
qp: "def=bar&abc=foo",
151+
b: testBody,
152+
s: testSignature,
153+
e: true,
154+
},
155+
{
156+
ts: testTs,
157+
qp: testQp,
158+
b: testBody,
159+
s: "wrong signature",
160+
e: false,
161+
},
162+
}
163+
164+
for i, tt := range cases {
165+
v := NewValidator(testKey, nil, nil, nil)
166+
r := v.ValidSignature(tt.ts, tt.qp, []byte(tt.b), tt.s)
167+
if r != tt.e {
168+
t.Errorf("Unexpected error validating signature: %s, test case: %d", tt.s, i)
169+
}
170+
}
171+
}
172+
173+
func testHandler(w http.ResponseWriter, r *http.Request) {
174+
175+
}
176+
func TestValidate(t *testing.T) {
177+
var cases = []struct {
178+
k string
179+
ts string
180+
s string
181+
sh string
182+
tsh string
183+
e int
184+
}{
185+
{
186+
k: testKey,
187+
ts: testTs,
188+
s: testSignature,
189+
sh: sHeader,
190+
tsh: tsHeader,
191+
e: http.StatusOK,
192+
},
193+
{
194+
k: "",
195+
ts: testTs,
196+
s: testSignature,
197+
sh: sHeader,
198+
tsh: tsHeader,
199+
e: http.StatusUnauthorized,
200+
},
201+
{
202+
k: testKey,
203+
ts: "",
204+
s: testSignature,
205+
sh: sHeader,
206+
tsh: tsHeader,
207+
e: http.StatusUnauthorized,
208+
},
209+
{
210+
k: testKey,
211+
ts: testTs,
212+
s: "",
213+
sh: sHeader,
214+
tsh: tsHeader,
215+
e: http.StatusUnauthorized,
216+
},
217+
{
218+
k: testKey,
219+
ts: testTs,
220+
s: testSignature,
221+
sh: "wrong-header",
222+
tsh: tsHeader,
223+
e: http.StatusUnauthorized,
224+
},
225+
{
226+
k: testKey,
227+
ts: testTs,
228+
s: testSignature,
229+
sh: sHeader
230+
tsh: "wrong-header",
231+
e: http.StatusUnauthorized,
232+
},
233+
}
234+
235+
for i, tt := range cases {
236+
v := NewValidator(tt.k, nil, nil, nil)
237+
ts := httptest.NewServer(v.Validate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
238+
w.WriteHeader(http.StatusOK)
239+
})))
240+
defer ts.Close()
241+
242+
client := &http.Client{}
243+
req, _ := http.NewRequest("GET", ts.URL+"?"+testQp, strings.NewReader(testBody))
244+
req.Header.Set(tt.sh, tt.s)
245+
req.Header.Set(tt.tsh, tt.ts)
246+
res, _ := client.Do(req)
247+
if res.StatusCode != tt.e {
248+
t.Errorf("Unexpected response code: %s, test case: %d", res.Status, i)
249+
}
250+
}
251+
252+
}

0 commit comments

Comments
 (0)