Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 29 additions & 47 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type checkPerformer func(API, string) (checker.DomainResult, error)
// Any POST request accepts either URL query parameters or data value parameters,
// and prefers the latter if both are present.
type API struct {
Database db.Database
Database *db.SQLDatabase
CheckDomain checkPerformer
List PolicyList
DontScan map[string]bool
Expand All @@ -61,7 +61,7 @@ type PolicyList interface {
type EmailSender interface {
// SendValidation sends a validation e-mail for a particular domain,
// with a particular validation token.
SendValidation(*models.Domain, string) error
SendValidation(*models.PolicySubmission, string) error
}

// APIResponse wraps all the responses from this API.
Expand Down Expand Up @@ -90,7 +90,7 @@ func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.
}

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
policyChan := models.Domain{Name: domain}.AsyncPolicyListCheck(api.Database, api.List)
policyChan := models.PolicySubmission{Name: domain}.AsyncPolicyListCheck(api.Database.PendingPolicies, api.Database.Policies, api.List)
c := checker.Checker{
Cache: &checker.ScanCache{
ScanStore: api.Database,
Expand Down Expand Up @@ -171,47 +171,42 @@ func (api API) Scan(r *http.Request) APIResponse {
// MaxHostnames is the maximum number of hostnames that can be specified for a single domain's TLS policy.
const MaxHostnames = 8

// Extracts relevant parameters from http.Request for a POST to /api/queue
// TODO: also validate hostnames as FQDNs.
func getDomainParams(r *http.Request) (models.Domain, error) {
// Extracts relevant parameters from http.Request for a POST to /api/queue into PolicySubmission
// If MTASTS is set, doesn't try to extract hostnames. Otherwise, expects between 1 and MaxHostnames
// valid hostnames to be given in |r|.
func getDomainParams(r *http.Request) (models.PolicySubmission, error) {
name, err := getASCIIDomain(r)
if err != nil {
return models.Domain{}, err
return models.PolicySubmission{}, err
}
email, err := getParam("email", r)
if err != nil {
email = validationAddress(name)
}
mtasts := r.FormValue("mta-sts")
domain := models.Domain{
domain := models.PolicySubmission{
Name: name,
Email: email,
MTASTS: mtasts == "on",
State: models.StateUnconfirmed,
}
email, err := getParam("email", r)
if err == nil {
domain.Email = email
} else {
domain.Email = validationAddress(&domain)
}
queueWeeks, err := getInt("weeks", r, 4, 52, 4)
if err != nil {
return domain, err
}
domain.QueueWeeks = queueWeeks

if mtasts != "on" {
if !domain.MTASTS {
p := policy.TLSPolicy{Mode: "testing", MXs: make([]string, 0)}
for _, hostname := range r.PostForm["hostnames"] {
if len(hostname) == 0 {
continue
}
if !validDomainName(strings.TrimPrefix(hostname, ".")) {
return domain, fmt.Errorf("Hostname %s is invalid", hostname)
}
domain.MXs = append(domain.MXs, hostname)
p.MXs = append(p.MXs, hostname)
}
if len(domain.MXs) == 0 {
return domain, fmt.Errorf("No MX hostnames supplied for domain %s", domain.Name)
if len(p.MXs) == 0 {
return domain, fmt.Errorf("No MX hostnames supplied for domain %s", name)
}
if len(domain.MXs) > MaxHostnames {
if len(p.MXs) > MaxHostnames {
return domain, fmt.Errorf("No more than 8 MX hostnames are permitted")
}
domain.Policy = &p
}
return domain, nil
}
Expand All @@ -221,7 +216,7 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
// domain: Mail domain to queue a TLS policy for.
// mta_sts: "on" if domain supports MTA-STS, else "".
// hostnames: List of MX hostnames to put into this domain's TLS policy. Up to 8.
// Sets models.Domain object as response.
// Sets models.PolicySubmission object as response.
// weeks (optional, default 4): How many weeks is this domain queued for.
// email (optional): Contact email associated with domain.
// GET /api/queue?domain=<domain>
Expand All @@ -233,12 +228,14 @@ func (api API) Queue(r *http.Request) APIResponse {
if err != nil {
return badRequest(err.Error())
}
ok, msg, scan := domain.IsQueueable(api.Database, api.Database, api.List)
if !domain.CanUpdate(api.Database.Policies) {
return badRequest("existing submission can't be updated")
}
ok, msg := domain.HasValidScan(api.Database)
if !ok {
return badRequest(msg)
}
domain.PopulateFromScan(scan)
token, err := domain.InitializeWithToken(api.Database, api.Database)
token, err := domain.InitializeWithToken(api.Database.PendingPolicies, api.Database)
if err != nil {
return serverError(err.Error())
}
Expand All @@ -251,23 +248,8 @@ func (api API) Queue(r *http.Request) APIResponse {
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain.Name),
}
}
// GET: Retrieve domain status from queue
if r.Method == http.MethodGet {
domainName, err := getASCIIDomain(r)
if err != nil {
return badRequest(err.Error())
}
domainObj, err := models.GetDomain(api.Database, domainName)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{
StatusCode: http.StatusOK,
Response: domainObj,
}
}
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
Message: "/api/queue only accepts POST requests"}
}

// Validate handles requests to /api/validate
Expand All @@ -284,7 +266,7 @@ func (api API) Validate(r *http.Request) APIResponse {
Message: "/api/validate only accepts POST requests"}
}
tokenData := models.Token{Token: token}
domain, userErr, dbErr := tokenData.Redeem(api.Database, api.Database)
domain, userErr, dbErr := tokenData.Redeem(api.Database.PendingPolicies, api.Database.Policies, api.Database)
if userErr != nil {
return badRequest(userErr.Error())
}
Expand Down
1 change: 1 addition & 0 deletions db/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Database structure
8 changes: 0 additions & 8 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ type Database interface {
PutLocalStats(time.Time) (checker.AggregatedScan, error)
// Gets counts per day of hosts supporting MTA-STS for a given source.
GetStats(string) (stats.Series, error)
// Upserts domain state.
PutDomain(models.Domain) error
// Retrieves state of a domain
GetDomain(string, models.DomainState) (models.Domain, error)
// Retrieves all domains in a particular state.
GetDomains(models.DomainState) ([]models.Domain, error)
SetStatus(string, models.DomainState) error
RemoveDomain(string, models.DomainState) (models.Domain, error)
ClearTables() error
}

Expand Down
122 changes: 122 additions & 0 deletions db/policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package db

import (
"database/sql"
"fmt"
"strings"

"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
)

// PolicyDB is a database of PolicySubmissions.
type PolicyDB struct {
tableName string
conn *sql.DB
strict bool
}

func (p *PolicyDB) formQuery(query string) string {
return fmt.Sprintf(query, p.tableName, "domain, email, mta_sts, mxs, mode")
}

type scanner interface {
Scan(dest ...interface{}) error
}

func (p *PolicyDB) scanPolicy(result scanner) (models.PolicySubmission, error) {
data := models.PolicySubmission{Policy: new(policy.TLSPolicy)}
var rawMXs string
err := result.Scan(
&data.Name, &data.Email,
&data.MTASTS, &rawMXs, &data.Policy.Mode)
data.Policy.MXs = strings.Split(rawMXs, ",")
return data, err
}

// GetPolicies returns a list of policy submissions that match
// the mtasts status given.
func (p *PolicyDB) GetPolicies(mtasts bool) ([]models.PolicySubmission, error) {
rows, err := p.conn.Query(p.formQuery(
"SELECT %[2]s FROM %[1]s WHERE mta_sts=$1"), mtasts)
if err != nil {
return nil, err
}
defer rows.Close()
policies := []models.PolicySubmission{}
for rows.Next() {
policy, err := p.scanPolicy(rows)
if err != nil {
return nil, err
}
policies = append(policies, policy)
}
return policies, nil
}

// GetPolicy returns the policy submission for the given domain.
func (p *PolicyDB) GetPolicy(domainName string) (models.PolicySubmission, bool, error) {
row := p.conn.QueryRow(p.formQuery(
"SELECT %[2]s FROM %[1]s WHERE domain=$1"), domainName)
result, err := p.scanPolicy(row)
if err == sql.ErrNoRows {
return result, false, nil
}
return result, true, err
}

// RemovePolicy removes the policy submission with the given domain from
// the database.
func (p *PolicyDB) RemovePolicy(domainName string) (models.PolicySubmission, error) {
row := p.conn.QueryRow(p.formQuery(
"DELETE FROM %[1]s WHERE domain=$1 RETURNING %[2]s"), domainName)
return p.scanPolicy(row)
}

// PutOrUpdatePolicy upserts the given policy into the data store, if
// CanUpdate passes.
func (p *PolicyDB) PutOrUpdatePolicy(ps *models.PolicySubmission) error {
if p.strict && !ps.CanUpdate(p) {
return fmt.Errorf("can't update policy in restricted table")
}
if p.strict && ps.Policy == nil {
return fmt.Errorf("can't degrade policy in restricted table")
}
if ps.Policy == nil {
ps.Policy = &policy.TLSPolicy{MXs: []string{}, Mode: ""}
}
_, err := p.conn.Exec(p.formQuery(
"INSERT INTO %[1]s(%[2]s) VALUES($1, $2, $3, $4, $5) "+
"ON CONFLICT (domain) DO UPDATE SET "+
"email=$2, mta_sts=$3, mxs=$4, mode=$5"),
ps.Name, ps.Email, ps.MTASTS,
strings.Join(ps.Policy.MXs[:], ","), ps.Policy.Mode)
return err
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DB whose policies should be validated-- all Pending policies.
func (p *PolicyDB) DomainsToValidate() ([]string, error) {
domains := []string{}
data, err := p.GetPolicies(true)
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// a particular domain in Pending.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
data, ok, err := db.PendingPolicies.GetPolicy(domain)
if !ok {
err = fmt.Errorf("domain %s not in database", domain)
}
if err != nil {
return []string{}, err
}
return data.Policy.MXs, nil
}
19 changes: 19 additions & 0 deletions db/scripts/init_tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ CREATE TABLE IF NOT EXISTS blacklisted_emails
timestamp TIMESTAMP
);

CREATE TABLE IF NOT EXISTS pending_policies
(
domain TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
mta_sts BOOLEAN DEFAULT FALSE,
mxs TEXT NOT NULL,
mode VARCHAR(255) NOT NULL
);


CREATE TABLE IF NOT EXISTS policies
(
domain TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
mta_sts BOOLEAN DEFAULT FALSE,
mxs TEXT NOT NULL,
mode VARCHAR(255) NOT NULL
);

-- Schema change: add "last_updated" timestamp column if it doesn't exist.

ALTER TABLE domains ADD COLUMN IF NOT EXISTS last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
Expand Down
Loading