Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainR
}
}
result.PreferredHostnames = checkedHostnames
result.MTASTSResult = c.checkMTASTS(domain, result.HostnameResults)
result.MTASTSResult = c.CheckMTASTS(domain, result.HostnameResults)

// Derive Domain code from Hostname results.
if len(checkedHostnames) == 0 {
Expand Down
4 changes: 3 additions & 1 deletion checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
}
}

func (c Checker) checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
// CheckMTASTS performs all associated checks for a particular domain's
// MTA-STS support.
func (c Checker) CheckMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
if c.checkMTASTSOverride != nil {
// Allow the Checker to mock this function.
return c.checkMTASTSOverride(domain, hostnameResults)
Expand Down
28 changes: 20 additions & 8 deletions db/sqldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ func (db *SQLDatabase) PutDomain(domain models.Domain) error {
return err
}

// UpdateDomainPolicy allows us to update the internal data about a particular domain.
func (db *SQLDatabase) UpdateDomainPolicy(domain models.Domain) error {
_, err := db.conn.Exec("UPDATE domains SET data=$2, status=$3 WHERE domain=$1 AND mta_sts=TRUE",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will update the policy and status for all matching domains - there could but multiple (domain, status) pairs returned. Seems like updating status for all of them could cause a collision.

domain.Name, strings.Join(domain.MXs[:], ","), domain.State)
return err

}

// GetDomain retrieves the status and information associated with a particular
// mailserver domain.
func (db SQLDatabase) GetDomain(domain string, state models.DomainState) (models.Domain, error) {
Expand All @@ -225,7 +233,7 @@ func (db SQLDatabase) GetDomain(domain string, state models.DomainState) (models
// GetDomains retrieves all the domains which match a particular state,
// that are not in MTA_STS mode
func (db SQLDatabase) GetDomains(state models.DomainState) ([]models.Domain, error) {
return db.queryDomainsWhere("status=$1", state)
return db.queryDomainsWhere("status=$1 AND mta_sts=FALSE", state)
}

// GetMTASTSDomains retrieves domains which wish their policy to be queued with their MTASTS.
Expand Down Expand Up @@ -324,31 +332,35 @@ func (db SQLDatabase) queryDomainsWhere(condition string, args ...interface{}) (
return domains, nil
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DomainsToValidate [interface DomainPolicyStore] retrieves domains from the
// DB whose policies should be validated.
func (db SQLDatabase) DomainsToValidate() ([]string, error) {
domains := []string{}
data, err := db.GetDomains(models.StateTesting)
if err != nil {
return domains, err
}
dataMTASTS, err := db.GetMTASTSDomains()
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
for _, domainInfo := range dataMTASTS {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// GetDomainPolicy [interface DomainPolicyStore] retrieves the domain object for
// a particular domain.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
func (db SQLDatabase) GetDomainPolicy(domain string) (models.Domain, error) {
data, err := db.GetDomain(domain, models.StateEnforce)
if err != nil {
data, err = db.GetDomain(domain, models.StateTesting)
}
if err != nil {
return []string{}, err
}
return data.MXs, nil
return data, err
}

// GetHostnameScan retrives most recent scan from database.
Expand Down
44 changes: 22 additions & 22 deletions db/sqldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,28 +270,6 @@ func TestDomainsToValidate(t *testing.T) {
}
}

func TestHostnamesForDomain(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "x", MXs: []string{"x.com", "y.org"}})
database.PutDomain(models.Domain{Name: "y"})
database.SetStatus("x", models.StateTesting)
database.SetStatus("y", models.StateTesting)
result, err := database.HostnamesForDomain("x")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) != 2 || result[0] != "x.com" || result[1] != "y.org" {
t.Errorf("Expected two hostnames, x.com and y.org\n")
}
result, err = database.HostnamesForDomain("y")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) > 0 {
t.Errorf("Expected no hostnames to be returned, got %s\n", result[0])
}
}

func TestPutAndIsBlacklistedEmail(t *testing.T) {
defer database.ClearTables()

Expand Down Expand Up @@ -454,3 +432,25 @@ func TestGetMTASTSDomains(t *testing.T) {
}
}
}

func TestUpdateDomainPolicy(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "no-mtasts"})
database.PutDomain(models.Domain{Name: "mtasts", MTASTS: true, Email: "real-email"})
database.UpdateDomainPolicy(models.Domain{Name: "no-mtasts", State: models.StateEnforce})
database.UpdateDomainPolicy(models.Domain{Name: "mtasts", State: models.StateEnforce, MXs: []string{"hostname"}, Email: "fake-email"})
domain, _ := database.GetDomainPolicy("no-mtasts")
if domain.State == models.StateEnforce {
t.Errorf("Expected State to not update since unicorns isn't MTASTS")
}
domain, _ = database.GetDomainPolicy("mtasts")
if domain.State != models.StateEnforce {
t.Errorf("Expected State to update after UpdateDomainPolicy")
}
if len(domain.MXs) != 1 || domain.MXs[0] != "hostname" {
t.Errorf("Expected MXs to update after UpdateDomainPolicy")
}
if domain.Email != "real-email" {
t.Errorf("Did not expect Email to update after UpdateDomainPolicy")
}
}
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,14 @@ func main() {
}
if os.Getenv("VALIDATE_QUEUED") == "1" {
log.Println("[Starting queued validator]")
go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
v := validator.Validator{
Name: "Testing and enforced domains",
Store: db,
Interval: 24 * time.Hour,
CheckPerformer: validator.GetDBCheck(db.UpdateDomainPolicy),
}
go v.Run()
// go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
}
ServePublicEndpoints(&api, &cfg)
}
12 changes: 12 additions & 0 deletions models/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/util"
)

/* Domain represents an email domain's TLS policy.
Expand Down Expand Up @@ -142,6 +143,17 @@ func (d Domain) AsyncPolicyListCheck(store domainStore, list policyList) <-chan
return result
}

// SamePolicy checks whether the underlying policy represented by Domain
// and the one picked up by the MTA-STS check represent the same policy.
func (d *Domain) SamePolicy(result *checker.MTASTSResult) bool {
if (result.Mode == "enforce" && d.State != StateEnforce) ||
(result.Mode == "testing" && d.State != StateTesting) ||
result.Mode == "none" {
return false
}
return util.ListsEqual(d.MXs, result.MXs)
}

// GetDomain retrieves Domain with the most "important" state.
// At any given time, there can only be one domain that's either StateEnforce
// or StateTesting. If that domain exists in the store, return that one.
Expand Down
21 changes: 16 additions & 5 deletions policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/http"
"sync"
"time"

"github.com/EFForg/starttls-backend/models"
)

// policyURL is the default URL from which to fetch the policy JSON.
Expand Down Expand Up @@ -68,7 +70,7 @@ type UpdatedList struct {
*List
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DomainsToValidate [interface DomainPolicyStore] retrieves domains from the
// DB whose policies should be validated.
func (l *UpdatedList) DomainsToValidate() ([]string, error) {
l.mu.RLock()
Expand All @@ -80,14 +82,23 @@ func (l *UpdatedList) DomainsToValidate() ([]string, error) {
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// GetDomainPolicy [interface DomainPolicyStore] retrieves the domain object for
// a particular domain.
func (l *UpdatedList) HostnamesForDomain(domain string) ([]string, error) {
func (l *UpdatedList) GetDomainPolicy(domain string) (models.Domain, error) {
policy, err := l.Get(domain)
if err != nil {
return []string{}, err
return models.Domain{}, err
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this reuse of the domain model!

}
domainObj := models.Domain{
Name: domain,
MXs: policy.MXs,
}
if policy.Mode == "enforce" {
domainObj.State = models.StateEnforce
} else if policy.Mode == "testing" {
domainObj.State = models.StateTesting
}
return policy.MXs, nil
return domainObj, nil
}

// Get safely reads from the underlying policy list and returns a TLSPolicy for a domain
Expand Down
6 changes: 3 additions & 3 deletions policy/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func TestHostnamesForDomain(t *testing.T) {
var updatedList = List{Policies: map[string]TLSPolicy{
"eff.org": TLSPolicy{MXs: hostnames}}}
list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second)
returned, err := list.HostnamesForDomain("eff.org")
returned, err := list.GetDomainPolicy("eff.org")
if err != nil {
t.Fatalf("Encountered %v", err)
}
if !reflect.DeepEqual(returned, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned)
if !reflect.DeepEqual(returned.MXs, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned.MXs)
}
}

Expand Down
27 changes: 27 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package util

import (
"reflect"
)

// ListsEqual checks that two lists have the same elements,
// regardless of order.
func ListsEqual(x []string, y []string) bool {
// Transform each list into a histogram
xMap := make(map[string]uint)
yMap := make(map[string]uint)
for _, element := range x {
if _, ok := xMap[element]; !ok {
xMap[element] = 0
}
xMap[element]++
}
for _, element := range y {
if _, ok := yMap[element]; !ok {
yMap[element] = 0
}
yMap[element]++
}
// Compare the histogram maps
return reflect.DeepEqual(xMap, yMap)
}
29 changes: 29 additions & 0 deletions util/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package util

import (
"testing"
)

func TestListsEqual(t *testing.T) {
testCases := []struct {
x []string
y []string
expected bool
}{
{[]string{}, []string{}, true},
{[]string{"a"}, []string{}, false},
{[]string{"a"}, []string{"a"}, true},
{[]string{"a", "a"}, []string{"a"}, false},
{[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true},
{[]string{"b", "a", "c"}, []string{"a", "b", "c"}, true},
{[]string{"b", "a", "b", "c"}, []string{"a", "b", "c"}, false},
{[]string{"b", "a", "b", "c"}, []string{"a", "b", "a", "c"}, false},
{[]string{"a", "a", "b", "c"}, []string{"a", "b", "a", "c"}, true},
}
for _, testCase := range testCases {
got := ListsEqual(testCase.x, testCase.y)
if got != testCase.expected {
t.Errorf("Compared %v and %v, expected %v, got %v", testCase.x, testCase.y, testCase.expected, got)
}
}
}
Loading