Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 25 additions & 42 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type checkPerformer func(API, string) (checker.DomainResult, error)
// Any POST request accepts either URL query parameters or data value parameters,
// and prefers the latter if both are present.
type API struct {
Database db.Database
Database *db.SQLDatabase
checkDomainOverride checkPerformer
List PolicyList
DontScan map[string]bool
Expand All @@ -64,7 +64,7 @@ type PolicyList interface {
type EmailSender interface {
// SendValidation sends a validation e-mail for a particular domain,
// with a particular validation token.
SendValidation(*models.Domain, string) error
SendValidation(*models.PolicySubmission, string) error
}

type response struct {
Expand Down Expand Up @@ -117,7 +117,7 @@ func (api *API) RegisterHandlers(mux *http.ServeMux) http.Handler {
}

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
policyChan := models.Domain{Name: domain}.AsyncPolicyListCheck(api.Database, api.List)
policyChan := models.PolicySubmission{Name: domain}.AsyncPolicyListCheck(api.Database.PendingPolicies, api.Database.Policies, api.List)
c := checker.Checker{
Cache: &checker.ScanCache{
ScanStore: api.Database,
Expand Down Expand Up @@ -198,47 +198,43 @@ func (api API) scan(r *http.Request) response {
// MaxHostnames is the maximum number of hostnames that can be specified for a single domain's TLS policy.
const MaxHostnames = 8

// Extracts relevant parameters from http.Request for a POST to /api/queue
// TODO: also validate hostnames as FQDNs.
func getDomainParams(r *http.Request) (models.Domain, error) {
// Extracts relevant parameters from http.Request for a POST to /api/queue into PolicySubmission
// If MTASTS is set, doesn't try to extract hostnames. Otherwise, expects between 1 and MaxHostnames
// valid hostnames to be given in |r|.
func getDomainParams(r *http.Request) (models.PolicySubmission, error) {
name, err := getASCIIDomain(r)
if err != nil {
return models.Domain{}, err
return models.PolicySubmission{}, err
}
mtasts := r.FormValue("mta-sts")
domain := models.Domain{
domain := models.PolicySubmission{
Name: name,
MTASTS: mtasts == "on",
State: models.StateUnconfirmed,
}
givenEmail, err := getParam("email", r)
if err == nil {
domain.Email = givenEmail
} else {
domain.Email = email.ValidationAddress(&domain)
domain.Email = email.ValidationAddress(name)
}
queueWeeks, err := getInt("weeks", r, 4, 52, 4)
if err != nil {
return domain, err
}
domain.QueueWeeks = queueWeeks

if mtasts != "on" {
if !domain.MTASTS {
p := policy.TLSPolicy{Mode: "testing", MXs: make([]string, 0)}
for _, hostname := range r.PostForm["hostnames"] {
if len(hostname) == 0 {
continue
}
if !util.ValidDomainName(strings.TrimPrefix(hostname, ".")) {
return domain, fmt.Errorf("Hostname %s is invalid", hostname)
}
domain.MXs = append(domain.MXs, hostname)
p.MXs = append(p.MXs, hostname)
}
if len(domain.MXs) == 0 {
return domain, fmt.Errorf("No MX hostnames supplied for domain %s", domain.Name)
if len(p.MXs) == 0 {
return domain, fmt.Errorf("No MX hostnames supplied for domain %s", name)
}
if len(domain.MXs) > MaxHostnames {
if len(p.MXs) > MaxHostnames {
return domain, fmt.Errorf("No more than 8 MX hostnames are permitted")
}
domain.Policy = &p
}
return domain, nil
}
Expand All @@ -248,7 +244,7 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
// domain: Mail domain to queue a TLS policy for.
// mta_sts: "on" if domain supports MTA-STS, else "".
// hostnames: List of MX hostnames to put into this domain's TLS policy. Up to 8.
// Sets models.Domain object as response.
// Sets models.PolicySubmission object as response.
// weeks (optional, default 4): How many weeks is this domain queued for.
// email (optional): Contact email associated with domain.
// GET /api/queue?domain=<domain>
Expand All @@ -260,12 +256,14 @@ func (api API) queue(r *http.Request) response {
if err != nil {
return badRequest(err.Error())
}
ok, msg, scan := domain.IsQueueable(api.Database, api.Database, api.List)
if !domain.CanUpdate(api.Database.Policies) {
return badRequest("existing submission can't be updated")
}
ok, msg := domain.HasValidScan(api.Database)
if !ok {
return badRequest(msg)
}
domain.PopulateFromScan(scan)
token, err := domain.InitializeWithToken(api.Database, api.Database)
token, err := domain.InitializeWithToken(api.Database.PendingPolicies, api.Database)
if err != nil {
return serverError(err.Error())
}
Expand All @@ -278,23 +276,8 @@ func (api API) queue(r *http.Request) response {
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain.Name),
}
}
// GET: Retrieve domain status from queue
if r.Method == http.MethodGet {
domainName, err := getASCIIDomain(r)
if err != nil {
return badRequest(err.Error())
}
domainObj, err := models.GetDomain(api.Database, domainName)
if err != nil {
return response{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return response{
StatusCode: http.StatusOK,
Response: domainObj,
}
}
return response{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
Message: "/api/queue only accepts POST requests"}
}

// Validate handles requests to /api/validate
Expand All @@ -311,7 +294,7 @@ func (api API) validate(r *http.Request) response {
Message: "/api/validate only accepts POST requests"}
}
tokenData := models.Token{Token: token}
domain, userErr, dbErr := tokenData.Redeem(api.Database, api.Database)
domain, userErr, dbErr := tokenData.Redeem(api.Database.PendingPolicies, api.Database.Policies, api.Database)
if userErr != nil {
return badRequest(userErr.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (l mockList) HasDomain(domain string) bool {
// Mock emailer
type mockEmailer struct{}

func (e mockEmailer) SendValidation(domain *models.Domain, token string) error { return nil }
func (e mockEmailer) SendValidation(domain *models.PolicySubmission, token string) error { return nil }

func testHTMLPost(path string, data url.Values, t *testing.T) ([]byte, int) {
req, err := http.NewRequest("POST", server.URL+path, strings.NewReader(data.Encode()))
Expand Down
57 changes: 13 additions & 44 deletions api/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"net/url"
"strings"
"testing"

"github.com/EFForg/starttls-backend/models"
)

func validQueueData(scan bool) url.Values {
Expand Down Expand Up @@ -78,34 +76,6 @@ func TestQueueDomainHidesToken(t *testing.T) {
}
}

func TestQueueDomainQueueWeeks(t *testing.T) {
defer teardown()

requestData := validQueueData(true)
requestData.Set("weeks", "50")
http.PostForm(server.URL+"/api/queue", requestData)
resp, _ := http.Get(server.URL + "/api/queue?domain=" + requestData.Get("domain"))

responseBody, _ := ioutil.ReadAll(resp.Body)
if !bytes.Contains(responseBody, []byte("50")) {
t.Errorf("Queueing domain should set weeks field properly")
}
}

func TestQueueDomainInvalidWeeks(t *testing.T) {
defer teardown()

requestData := validQueueData(true)
invalidWeeks := []string{"53", "3", "0", "-1", "abc", "5.5"}
for _, week := range invalidWeeks {
requestData.Set("weeks", week)
resp, _ := http.PostForm(server.URL+"/api/queue", requestData)
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("Expected POST to api/queue to fail with weeks=%s.", week)
}
}
}

// Tests basic queuing workflow.
// Requests domain to be queued, and validates corresponding e-mail token.
// Domain status should then be updated to "queued".
Expand All @@ -127,20 +97,14 @@ func TestBasicQueueWorkflow(t *testing.T) {
queueDomainGetPath := server.URL + "/api/queue?domain=" + queueDomainPostData.Get("domain")
resp, _ = http.Get(queueDomainGetPath)

// 2-T. Check to see domain status was initialized to 'unvalidated'
domainBody, _ := ioutil.ReadAll(resp.Body)
domain := models.Domain{}
err := json.Unmarshal(domainBody, &response{Response: &domain})
if err != nil {
t.Fatalf("Returned invalid JSON object:%v\n", string(domainBody))
}
if domain.State != "unvalidated" {
t.Fatalf("Initial state for domains should be 'unvalidated'")
// 2-T. Check to see domain is in pending
domain, ok, err := api.Database.PendingPolicies.GetPolicy("example.com")
if err != nil || !ok {
t.Errorf("Queued domain should be in pending")
}
if len(domain.MXs) != 1 {
if len(domain.Policy.MXs) != 1 {
t.Fatalf("Domain should have loaded one hostname into policy")
}

// 3. Validate domain token
token, err := api.Database.GetTokenByDomain(queueDomainPostData.Get("domain"))
if err != nil {
Expand All @@ -154,7 +118,7 @@ func TestBasicQueueWorkflow(t *testing.T) {
}

// 3-T. Ensure response body contains domain name
domainBody, _ = ioutil.ReadAll(resp.Body)
domainBody, _ := ioutil.ReadAll(resp.Body)
var responseObj map[string]interface{}
err = json.Unmarshal(domainBody, &responseObj)
if err != nil {
Expand All @@ -179,8 +143,13 @@ func TestBasicQueueWorkflow(t *testing.T) {
if err != nil {
t.Fatalf("Returned invalid JSON object:%v\n", string(domainBody))
}
if domain.State != "queued" {
t.Fatalf("Token validation should have automatically queued domain")
_, ok, err = api.Database.Policies.GetPolicy(domain.Name)
if err != nil || !ok {
t.Errorf("Token validation should have automatically queued domain")
}
_, ok, err = api.Database.PendingPolicies.GetPolicy(domain.Name)
if ok {
t.Errorf("Token validation should have removed domain from pending")
}
}

Expand Down
1 change: 1 addition & 0 deletions db/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Database structure
8 changes: 0 additions & 8 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ type Database interface {
PutLocalStats(time.Time) (checker.AggregatedScan, error)
// Gets counts per day of hosts supporting MTA-STS for a given source.
GetStats(string) (stats.Series, error)
// Upserts domain state.
PutDomain(models.Domain) error
// Retrieves state of a domain
GetDomain(string, models.DomainState) (models.Domain, error)
// Retrieves all domains in a particular state.
GetDomains(models.DomainState) ([]models.Domain, error)
SetStatus(string, models.DomainState) error
RemoveDomain(string, models.DomainState) (models.Domain, error)
ClearTables() error
}

Expand Down
123 changes: 123 additions & 0 deletions db/policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package db

import (
"database/sql"
"fmt"
"strings"

"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
)

// PolicyDB is a database of PolicySubmissions.
type PolicyDB struct {
tableName string
conn *sql.DB
strict bool
}

func (p *PolicyDB) formQuery(query string) string {
return fmt.Sprintf(query, p.tableName, "domain, email, mta_sts, mxs, mode")
}

type scanner interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite, I think-- It looks like Scanner expects a single interface{} destination, whereas here we are scanning into a varied number: ...interface{}

Scan(dest ...interface{}) error
}

func (p *PolicyDB) scanPolicy(result scanner) (models.PolicySubmission, error) {
data := models.PolicySubmission{Policy: new(policy.TLSPolicy)}
var rawMXs string
err := result.Scan(
&data.Name, &data.Email,
&data.MTASTS, &rawMXs, &data.Policy.Mode)
data.Policy.MXs = strings.Split(rawMXs, ",")
return data, err
}

// GetPolicies returns a list of policy submissions that match
// the mtasts status given.
func (p *PolicyDB) GetPolicies(mtasts bool) ([]models.PolicySubmission, error) {
rows, err := p.conn.Query(p.formQuery(
"SELECT %[2]s FROM %[1]s WHERE mta_sts=$1"), mtasts)
if err != nil {
return nil, err
}
defer rows.Close()
policies := []models.PolicySubmission{}
for rows.Next() {
policy, err := p.scanPolicy(rows)
if err != nil {
return nil, err
}
policies = append(policies, policy)
}
return policies, nil
}

// GetPolicy returns the policy submission for the given domain.
// Returns the submission (if found), whether it was found, and any errors encountered.
func (p *PolicyDB) GetPolicy(domainName string) (policy models.PolicySubmission, ok bool, err error) {
row := p.conn.QueryRow(p.formQuery(
"SELECT %[2]s FROM %[1]s WHERE domain=$1"), domainName)
result, err := p.scanPolicy(row)
if err == sql.ErrNoRows {
return result, false, nil
}
return result, true, err
}

// RemovePolicy removes the policy submission with the given domain from
// the database.
func (p *PolicyDB) RemovePolicy(domainName string) (models.PolicySubmission, error) {
row := p.conn.QueryRow(p.formQuery(
"DELETE FROM %[1]s WHERE domain=$1 RETURNING %[2]s"), domainName)
return p.scanPolicy(row)
}

// PutOrUpdatePolicy upserts the given policy into the data store, if
// CanUpdate passes.
func (p *PolicyDB) PutOrUpdatePolicy(ps *models.PolicySubmission) error {
if p.strict && !ps.CanUpdate(p) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like these guards - this seems like a good place for this type of validation. It seems like we're never running this code with strict = true. Is there somewhere we should be?

return fmt.Errorf("can't update policy in restricted table")
}
if p.strict && ps.Policy == nil {
return fmt.Errorf("can't degrade policy in restricted table")
}
if ps.Policy == nil {
ps.Policy = &policy.TLSPolicy{MXs: []string{}, Mode: ""}
}
_, err := p.conn.Exec(p.formQuery(
"INSERT INTO %[1]s(%[2]s) VALUES($1, $2, $3, $4, $5) "+
"ON CONFLICT (domain) DO UPDATE SET "+
"email=$2, mta_sts=$3, mxs=$4, mode=$5"),
ps.Name, ps.Email, ps.MTASTS,
strings.Join(ps.Policy.MXs[:], ","), ps.Policy.Mode)
return err
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DB whose policies should be validated-- all Pending policies.
func (db SQLDatabase) DomainsToValidate() ([]string, error) {
domains := []string{}
data, err := db.PendingPolicies.GetPolicies(false)
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// a particular domain in Pending.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
data, ok, err := db.PendingPolicies.GetPolicy(domain)
if !ok {
err = fmt.Errorf("domain %s not in database", domain)
}
if err != nil {
return []string{}, err
}
return data.Policy.MXs, nil
}
Loading