-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathapi.go
More file actions
393 lines (367 loc) · 12.3 KB
/
api.go
File metadata and controls
393 lines (367 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
package main
import (
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"os"
"strconv"
"strings"
"time"
"golang.org/x/net/idna"
"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/db"
"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
"github.com/getsentry/raven-go"
)
////////////////////////////////
// ***** REST API ***** //
////////////////////////////////
// cacheScanTime is the minimum time to cache each domain scan: a POST to
// /api/scan returns the stored scan instead of re-scanning while the latest
// scan of the current scan version is younger than this.
const cacheScanTime = time.Minute
// checkPerformer is the type for performing checks against an input domain.
// It is called with the API itself and the domain name, and returns a
// DomainResult object from the checker. defaultCheck is one implementation.
type checkPerformer func(API, string) (checker.DomainResult, error)
// API is the HTTP API that this service provides.
// All requests respond with an APIResponse, with fields:
//
//	{
//	   status_code // HTTP status code of request
//	   message     // Any error message accompanying the status_code. If 200, empty.
//	   response    // Response data (as JSON) from this request.
//	}
//
// The response is written as JSON, unless the request's Accept header
// contains "text/html", in which case it is rendered through an HTML template.
// Any POST request accepts either URL query parameters or data value parameters,
// and prefers the latter if both are present.
type API struct {
	// Database stores scans, pending policies and policies.
	Database *db.SQLDatabase
	// CheckDomain performs the actual scan for POST requests to /api/scan.
	CheckDomain checkPerformer
	// List is the policy list that defaultCheck hands to the policy-list check.
	List PolicyList
	// DontScan holds domains that Scan refuses to handle; any domain present
	// as a key is rejected, regardless of the stored value.
	DontScan map[string]bool
	// Emailer sends the validation e-mail when a domain is queued.
	Emailer EmailSender
	// Templates maps a template name to its parsed HTML template; it is
	// populated by parseTemplates and read by writeHTML.
	Templates map[string]*template.Template
}
// PolicyList interface wraps a policy-list like structure.
// The most important query you can perform is to fetch the policy
// for a particular domain.
type PolicyList interface {
	// HasDomain presumably reports whether the list contains the given
	// domain -- NOTE(review): confirm against the implementation.
	HasDomain(string) bool
	// Raw returns the list as a policy.List value.
	Raw() policy.List
}
// EmailSender interface wraps a back-end that can send e-mails.
// The Queue handler uses it to deliver the validation token for a submission.
type EmailSender interface {
	// SendValidation sends a validation e-mail for a particular domain,
	// with a particular validation token. A non-nil error means the e-mail
	// could not be sent.
	SendValidation(*models.PolicySubmission, string) error
}
// APIResponse wraps all the responses from this API.
type APIResponse struct {
	// StatusCode is the HTTP status code; it is both serialized in the body
	// and written as the status of the HTTP response.
	StatusCode int `json:"status_code"`
	// Message is any error message accompanying StatusCode.
	Message string `json:"message"`
	// Response is the response data for the request.
	Response interface{} `json:"response"`
	// templateName selects the HTML template used by writeHTML; when empty,
	// the "default" template is used. It is never serialized to JSON.
	templateName string `json:"-"`
}
// apiHandler is the signature of an API endpoint: it inspects the request and
// returns the APIResponse to send. API.wrapper adapts one into a net/http
// handler function.
type apiHandler func(r *http.Request) APIResponse
// wrapper adapts an apiHandler into a standard net/http handler function.
// The handler's response is reported through raven when its status is
// http.StatusInternalServerError, and is then rendered as HTML if the
// request's Accept header contains "text/html", or as JSON otherwise.
func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		resp := handler(r)
		if resp.StatusCode == http.StatusInternalServerError {
			raven.Capture(raven.NewPacket(resp.Message, raven.NewHttp(r)), nil)
		}

		// JSON is the default representation; HTML is opt-in via Accept.
		render := api.writeJSON
		if strings.Contains(r.Header.Get("accept"), "text/html") {
			render = api.writeHTML
		}
		render(w, resp)
	}
}
// defaultCheck is a checkPerformer that scans domain with the checker package.
// It first starts the policy-list check for the domain (which delivers its
// result over a channel), then runs the domain check with a scan cache backed
// by the API's database, and finally stores the policy-list result in the
// scan's ExtraResults under the "policylist" key. The returned error is
// always nil.
func defaultCheck(api API, domain string) (checker.DomainResult, error) {
	submission := models.PolicySubmission{Name: domain}
	// Started before the domain check so the two can proceed side by side.
	listCheck := submission.AsyncPolicyListCheck(api.Database.PendingPolicies, api.Database.Policies, api.List)

	scanCache := &checker.ScanCache{
		ScanStore:  api.Database,
		ExpireTime: 5 * time.Minute,
	}
	domainChecker := checker.Checker{Cache: scanCache, Timeout: 3 * time.Second}
	result := domainChecker.CheckDomain(domain, nil)

	// Block until the policy-list check has delivered its result.
	listResult := <-listCheck
	result.ExtraResults["policylist"] = &listResult
	return result, nil
}
// Scan is the handler for /api/scan.
//
//	POST /api/scan
//	     domain: Mail domain to scan.
//	     Scans domain and returns data from it. If the latest stored scan is
//	     younger than cacheScanTime and has the current scan version, that
//	     scan is returned instead of scanning again.
//	GET /api/scan?domain=<domain>
//	     Retrieves most recent scan for domain.
//
// Both set a models.Scan JSON as the response. A missing or unconvertible
// domain yields 400, a domain listed in DontScan yields 429, and any other
// HTTP method yields 405.
func (api API) Scan(r *http.Request) APIResponse {
	domain, err := getASCIIDomain(r)
	if err != nil {
		return APIResponse{StatusCode: http.StatusBadRequest, Message: err.Error()}
	}

	// Check if we shouldn't scan this domain. Indexing a nil map is safe, so
	// an unset DontScan simply never matches.
	if _, blocked := api.DontScan[domain]; blocked {
		return APIResponse{StatusCode: http.StatusTooManyRequests}
	}

	switch r.Method {
	case http.MethodPost:
		// 0. If last scan was recent and on same scan version, return cached scan.
		cached, lookupErr := api.Database.GetLatestScan(domain)
		stillFresh := lookupErr == nil &&
			cached.Version == models.ScanVersion &&
			time.Now().Before(cached.Timestamp.Add(cacheScanTime))
		if stillFresh {
			return APIResponse{
				StatusCode:   http.StatusOK,
				Response:     cached,
				templateName: "scan",
			}
		}

		// 1. Conduct scan via starttls-checker
		scanData, checkErr := api.CheckDomain(api, domain)
		if checkErr != nil {
			return APIResponse{StatusCode: http.StatusInternalServerError, Message: checkErr.Error()}
		}
		scan := models.Scan{
			Domain:    domain,
			Data:      scanData,
			Timestamp: time.Now(),
			Version:   models.ScanVersion,
		}

		// 2. Put scan into DB
		if putErr := api.Database.PutScan(scan); putErr != nil {
			return APIResponse{StatusCode: http.StatusInternalServerError, Message: putErr.Error()}
		}
		return APIResponse{
			StatusCode:   http.StatusOK,
			Response:     scan,
			templateName: "scan",
		}

	case http.MethodGet:
		// Just fetch the most recent scan.
		latest, lookupErr := api.Database.GetLatestScan(domain)
		if lookupErr != nil {
			return APIResponse{StatusCode: http.StatusNotFound, Message: lookupErr.Error()}
		}
		return APIResponse{StatusCode: http.StatusOK, Response: latest}

	default:
		return APIResponse{StatusCode: http.StatusMethodNotAllowed,
			Message: "/api/scan only accepts POST and GET requests"}
	}
}
// MaxHostnames is the maximum number of hostnames that can be specified for a
// single domain's TLS policy. getDomainParams rejects submissions listing more.
const MaxHostnames = 8
// getDomainParams extracts the relevant parameters from an http.Request for a
// POST to /api/queue into a PolicySubmission.
//
// Parameters read from |r|:
//   - domain: required; converted to ASCII via getASCIIDomain.
//   - email: optional; defaults to validationAddress(domain).
//   - mta-sts: "on" marks the submission as MTA-STS.
//   - hostnames: MX hostnames, read from the request body only.
//
// If MTASTS is set, doesn't try to extract hostnames. Otherwise, expects
// between 1 and MaxHostnames valid hostnames to be given in |r| and stores
// them in a "testing"-mode TLS policy on the submission. Empty hostname
// values are skipped.
//
// An error is returned when the domain is missing or invalid, a hostname is
// invalid, or the number of hostnames is out of range.
func getDomainParams(r *http.Request) (models.PolicySubmission, error) {
	name, err := getASCIIDomain(r)
	if err != nil {
		return models.PolicySubmission{}, err
	}
	email, err := getParam("email", r)
	if err != nil {
		// The e-mail is optional: fall back to the domain's validation address.
		email = validationAddress(name)
	}
	mtasts := r.FormValue("mta-sts")
	domain := models.PolicySubmission{
		Name:   name,
		Email:  email,
		MTASTS: mtasts == "on",
	}
	if !domain.MTASTS {
		p := policy.TLSPolicy{Mode: "testing", MXs: make([]string, 0)}
		// PostForm is populated here because the FormValue calls above have
		// already parsed the request.
		for _, hostname := range r.PostForm["hostnames"] {
			if len(hostname) == 0 {
				continue
			}
			// A single leading dot is tolerated for validation purposes; the
			// hostname is stored exactly as submitted.
			if !validDomainName(strings.TrimPrefix(hostname, ".")) {
				return domain, fmt.Errorf("Hostname %s is invalid", hostname)
			}
			p.MXs = append(p.MXs, hostname)
		}
		if len(p.MXs) == 0 {
			return domain, fmt.Errorf("No MX hostnames supplied for domain %s", name)
		}
		if len(p.MXs) > MaxHostnames {
			// Report the limit from the constant so the message cannot drift.
			return domain, fmt.Errorf("No more than %d MX hostnames are permitted", MaxHostnames)
		}
		domain.Policy = &p
	}
	return domain, nil
}
// Queue is the handler for /api/queue
//
//	POST /api/queue
//	     domain: Mail domain to queue a TLS policy for.
//	     mta-sts: "on" if domain supports MTA-STS, else "".
//	     hostnames: List of MX hostnames to put into this domain's TLS policy.
//	                Up to MaxHostnames. Not read when mta-sts is "on".
//	     email (optional): Contact email associated with domain.
//	     Sets a confirmation message (string) as response.
//
// The submission is rejected with 400 when its parameters are invalid, when an
// existing submission can't be updated, or when the domain has no valid scan.
// Failure to create the token or to send the validation e-mail yields 500.
// Any method other than POST yields 405.
func (api API) Queue(r *http.Request) APIResponse {
	if r.Method != http.MethodPost {
		return APIResponse{StatusCode: http.StatusMethodNotAllowed,
			Message: "/api/queue only accepts POST requests"}
	}

	// Insert this domain into the queue.
	submission, err := getDomainParams(r)
	if err != nil {
		return badRequest(err.Error())
	}
	if !submission.CanUpdate(api.Database.Policies) {
		return badRequest("existing submission can't be updated")
	}
	if valid, reason := submission.HasValidScan(api.Database); !valid {
		return badRequest(reason)
	}

	token, err := submission.InitializeWithToken(api.Database.PendingPolicies, api.Database)
	if err != nil {
		return serverError(err.Error())
	}
	if sendErr := api.Emailer.SendValidation(&submission, token); sendErr != nil {
		// The detailed cause is only logged; the client gets a generic message.
		log.Print(sendErr)
		return serverError("Unable to send validation e-mail")
	}

	confirmation := fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", submission.Name)
	return APIResponse{StatusCode: http.StatusOK, Response: confirmation}
}
// Validate handles requests to /api/validate
//
//	POST /api/validate
//	     token: token to validate/redeem
//	     Sets the queued domain name as response.
//
// The token parameter is checked before the HTTP method, so a request without
// a token yields 400 even when it is not a POST; a non-POST request carrying a
// token yields 405. A redemption error attributable to the user yields 400,
// and a database error yields 500.
func (api API) Validate(r *http.Request) APIResponse {
	token, err := getParam("token", r)
	switch {
	case err != nil:
		return APIResponse{StatusCode: http.StatusBadRequest, Message: err.Error()}
	case r.Method != http.MethodPost:
		return APIResponse{StatusCode: http.StatusMethodNotAllowed,
			Message: "/api/validate only accepts POST requests"}
	}

	redeemable := models.Token{Token: token}
	domain, userErr, dbErr := redeemable.Redeem(api.Database.PendingPolicies, api.Database.Policies, api.Database)
	switch {
	case userErr != nil:
		return badRequest(userErr.Error())
	case dbErr != nil:
		return serverError(dbErr.Error())
	}
	return APIResponse{StatusCode: http.StatusOK, Response: domain}
}
// getASCIIDomain reads the "domain" parameter from the request and converts
// it to its ASCII form with idna.ToASCII.
// It returns an error if the parameter is missing or cannot be converted.
func getASCIIDomain(r *http.Request) (string, error) {
	raw, err := getParam("domain", r)
	if err != nil {
		return raw, err
	}

	converted, convErr := idna.ToASCII(raw)
	if convErr == nil {
		return converted, nil
	}
	return "", fmt.Errorf("could not convert domain %s to ASCII (%s)", raw, convErr)
}
// getParam looks up the form value `param` of `http.Request` r and returns it
// lowercased. A missing or empty value is reported as an error.
func getParam(param string, r *http.Request) (string, error) {
	if value := r.FormValue(param); value != "" {
		return strings.ToLower(value), nil
	}
	return "", fmt.Errorf("query parameter %s not specified", param)
}
// getInt reads the form value `param` of `http.Request` r and parses it as an
// integer that must lie in [lowInc, highExc).
// If `param` isn't specified, defaultNum is returned without range checking.
// A value that is not an integer yields -1 and the parse error; a value
// outside the range yields the parsed number together with an error.
func getInt(param string, r *http.Request, lowInc int, highExc int, defaultNum int) (int, error) {
	raw := r.FormValue(param)
	if raw == "" {
		return defaultNum, nil
	}

	n, err := strconv.Atoi(raw)
	switch {
	case err != nil:
		return -1, err
	case n < lowInc:
		return n, fmt.Errorf("expected query parameter %s to be more than or equal to %d, was %d", param, lowInc, n)
	case n >= highExc:
		return n, fmt.Errorf("expected query parameter %s to be less than %d, was %d", param, highExc, n)
	}
	return n, nil
}
// writeJSON writes apiResponse as an indented JSON object to
// http.ResponseWriter `w`, using apiResponse.StatusCode as the HTTP status.
//
// The response is marshalled before any header is sent. If marshalling fails,
// a plain-text error with `http.StatusInternalServerError` is written instead;
// sending the status first would make that 500 impossible, because the status
// of a response can only be written once.
func (api *API) writeJSON(w http.ResponseWriter, apiResponse APIResponse) {
	b, err := json.MarshalIndent(apiResponse, "", " ")
	if err != nil {
		msg := fmt.Sprintf("Internal error: could not format JSON. (%s)\n", err)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(apiResponse.StatusCode)
	fmt.Fprintf(w, "%s\n", b)
}
// parseTemplates (re)builds api.Templates by parsing the "default" and "scan"
// templates from views/<name>.html.tmpl. A template that fails to parse is
// reported through raven and then terminates the process via log.Fatal.
func (api *API) parseTemplates() {
	api.Templates = make(map[string]*template.Template)
	for _, name := range []string{"default", "scan"} {
		parsed, err := template.ParseFiles("views/" + name + ".html.tmpl")
		if err != nil {
			raven.CaptureError(err, nil)
			log.Fatal(err)
		}
		api.Templates[name] = parsed
	}
}
// writeHTML renders apiResponse through one of the templates in api.Templates
// and writes the result to w with apiResponse.StatusCode as the HTTP status.
// The template is chosen by apiResponse.templateName, falling back to
// "default" when that is empty. An unknown template name is reported through
// raven and answered with http.StatusInternalServerError. An error while
// executing the template is logged and reported, but by then the status has
// already been sent.
func (api *API) writeHTML(w http.ResponseWriter, apiResponse APIResponse) {
	name := apiResponse.templateName
	if name == "" {
		name = "default"
	}
	tmpl, found := api.Templates[name]
	if !found {
		err := fmt.Errorf("Template not found: %s", name)
		raven.CaptureError(err, nil)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Add some additional useful fields for use in templates.
	view := struct {
		APIResponse
		BaseURL    string
		StatusText string
	}{
		APIResponse: apiResponse,
		StatusText:  http.StatusText(apiResponse.StatusCode),
		BaseURL:     os.Getenv("FRONTEND_WEBSITE_LINK"),
	}

	w.WriteHeader(apiResponse.StatusCode)
	if err := tmpl.Execute(w, view); err != nil {
		log.Println(err)
		raven.CaptureError(err, nil)
	}
}
// badRequest builds an APIResponse with status http.StatusBadRequest.
//
// When arguments are supplied, the message is fmt.Sprintf(format, a...).
// When none are supplied, format is used verbatim as the message. Callers in
// this file pass already-formatted text such as err.Error(), which may contain
// user-supplied '%' characters; running that through Sprintf would garble it
// into "%!d(MISSING)"-style output.
func badRequest(format string, a ...interface{}) APIResponse {
	message := format
	if len(a) > 0 {
		message = fmt.Sprintf(format, a...)
	}
	return APIResponse{
		StatusCode: http.StatusBadRequest,
		Message:    message,
	}
}
// serverError builds an APIResponse with status
// http.StatusInternalServerError.
//
// When arguments are supplied, the message is fmt.Sprintf(format, a...).
// When none are supplied, format is used verbatim as the message. Callers in
// this file pass already-formatted text such as err.Error(), and any '%' in
// that text would otherwise be garbled into "%!d(MISSING)"-style output.
func serverError(format string, a ...interface{}) APIResponse {
	message := format
	if len(a) > 0 {
		message = fmt.Sprintf(format, a...)
	}
	return APIResponse{
		StatusCode: http.StatusInternalServerError,
		Message:    message,
	}
}