Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainR
}
}
result.PreferredHostnames = checkedHostnames
result.MTASTSResult = c.checkMTASTS(domain, result.HostnameResults)
result.MTASTSResult = c.CheckMTASTS(domain, result.HostnameResults)

// Derive Domain code from Hostname results.
if len(checkedHostnames) == 0 {
Expand Down
4 changes: 3 additions & 1 deletion checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
}
}

func (c Checker) checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
// CheckMTASTS performs all associated checks for a particular domain's
// MTA-STS support.
func (c Checker) CheckMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
if c.checkMTASTSOverride != nil {
// Allow the Checker to mock this function.
return c.checkMTASTSOverride(domain, hostnameResults)
Expand Down
27 changes: 16 additions & 11 deletions db/sqldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ func (db *SQLDatabase) PutDomain(domain models.Domain) error {
return err
}

// UpdateDomainPolicy allows us to update the internal data about a particular domain.
func (db *SQLDatabase) UpdateDomainPolicy(domain models.Domain) error {
_, err := db.conn.Exec("UPDATE domains SET data=$2, status=$3 WHERE domain=$1 AND mta_sts=TRUE",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will update the policy and status for all matching domains - there could but multiple (domain, status) pairs returned. Seems like updating status for all of them could cause a collision.

domain.Name, strings.Join(domain.MXs[:], ","), domain.State)
return err

}

// GetDomain retrieves the status and information associated with a particular
// mailserver domain.
func (db SQLDatabase) GetDomain(domain string) (models.Domain, error) {
Expand All @@ -234,7 +242,7 @@ func (db SQLDatabase) GetDomain(domain string) (models.Domain, error) {
// GetDomains retrieves all the domains which match a particular state,
// that are not in MTA_STS mode
func (db SQLDatabase) GetDomains(state models.DomainState) ([]models.Domain, error) {
return db.getDomainsWhere("status=$1", state)
return db.getDomainsWhere("status=$1 AND mta_sts=FALSE", state)
}

// GetMTASTSDomains retrieves domains which wish their policy to be queued with their MTASTS.
Expand Down Expand Up @@ -312,20 +320,17 @@ func (db SQLDatabase) DomainsToValidate() ([]string, error) {
if err != nil {
return domains, err
}
dataMTASTS, err := db.GetMTASTSDomains()
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// a particular domain.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
data, err := db.GetDomain(domain)
if err != nil {
return []string{}, err
for _, domainInfo := range dataMTASTS {
domains = append(domains, domainInfo.Name)
}
return data.MXs, nil
return domains, nil
}

// GetHostnameScan retrives most recent scan from database.
Expand Down
42 changes: 22 additions & 20 deletions db/sqldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,6 @@ func TestDomainsToValidate(t *testing.T) {
}
}

func TestHostnamesForDomain(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "x", MXs: []string{"x.com", "y.org"}})
database.PutDomain(models.Domain{Name: "y"})
result, err := database.HostnamesForDomain("x")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) != 2 || result[0] != "x.com" || result[1] != "y.org" {
t.Errorf("Expected two hostnames, x.com and y.org\n")
}
result, err = database.HostnamesForDomain("y")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) > 0 {
t.Errorf("Expected no hostnames to be returned, got %s\n", result[0])
}
}

func TestPutAndIsBlacklistedEmail(t *testing.T) {
defer database.ClearTables()

Expand Down Expand Up @@ -447,3 +427,25 @@ func TestGetMTASTSDomains(t *testing.T) {
}
}
}

func TestUpdateDomainPolicy(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "no-mtasts"})
database.PutDomain(models.Domain{Name: "mtasts", MTASTSMode: "on", Email: "real-email"})
database.UpdateDomainPolicy(models.Domain{Name: "no-mtasts", State: models.StateEnforce})
database.UpdateDomainPolicy(models.Domain{Name: "mtasts", State: models.StateEnforce, MXs: []string{"hostname"}, Email: "fake-email"})
domain, _ := database.GetDomain("no-mtasts")
if domain.State == models.StateEnforce {
t.Errorf("Expected State to not update since unicorns isn't MTASTS")
}
domain, _ = database.GetDomain("mtasts")
if domain.State != models.StateEnforce {
t.Errorf("Expected State to update after UpdateDomainPolicy")
}
if len(domain.MXs) != 1 || domain.MXs[0] != "hostname" {
t.Errorf("Expected MXs to update after UpdateDomainPolicy")
}
if domain.Email != "real-email" {
t.Errorf("Did not expect Email to update after UpdateDomainPolicy")
}
}
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,14 @@ func main() {
}
if os.Getenv("VALIDATE_QUEUED") == "1" {
log.Println("[Starting queued validator]")
go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
v := validator.Validator{
Name: "Testing and enforced domains",
Store: db,
Interval: 24 * time.Hour,
CheckPerformer: validator.GetDBCheck(db.UpdateDomainPolicy),
}
go v.Run()
// go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
}
ServePublicEndpoints(&api, &cfg)
}
12 changes: 12 additions & 0 deletions models/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/util"
)

// Domain stores the preload state of a single domain.
Expand Down Expand Up @@ -128,3 +129,14 @@ func (d Domain) AsyncPolicyListCheck(store domainStore, list policyList) <-chan
go func() { result <- *d.PolicyListCheck(store, list) }()
return result
}

// SamePolicy checks whether the underlying policy represented by Domain
// and the one picked up by the MTA-STS check represent the same policy.
func (d *Domain) SamePolicy(result *checker.MTASTSResult) bool {
if (result.Mode == "enforce" && d.State != StateEnforce) ||
(result.Mode == "testing" && d.State != StateTesting) ||
result.Mode == "none" {
return false
}
return util.ListsEqual(d.MXs, result.MXs)
}
19 changes: 15 additions & 4 deletions policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/http"
"sync"
"time"

"github.com/EFForg/starttls-backend/models"
)

// policyURL is the default URL from which to fetch the policy JSON.
Expand Down Expand Up @@ -80,14 +82,23 @@ func (l *UpdatedList) DomainsToValidate() ([]string, error) {
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// GetDomain [interface Validator] retrieves the domain object for
// a particular domain.
func (l *UpdatedList) HostnamesForDomain(domain string) ([]string, error) {
func (l *UpdatedList) GetDomain(domain string) (models.Domain, error) {
policy, err := l.Get(domain)
if err != nil {
return []string{}, err
return models.Domain{}, err
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this reuse of the domain model!

}
domainObj := models.Domain{
Name: domain,
MXs: policy.MXs,
}
if policy.Mode == "enforce" {
domainObj.State = models.StateEnforce
} else if policy.Mode == "testing" {
domainObj.State = models.StateTesting
}
return policy.MXs, nil
return domainObj, nil
}

// Get safely reads from the underlying policy list and returns a TLSPolicy for a domain
Expand Down
6 changes: 3 additions & 3 deletions policy/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func TestHostnamesForDomain(t *testing.T) {
var updatedList = List{Policies: map[string]TLSPolicy{
"eff.org": TLSPolicy{MXs: hostnames}}}
list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second)
returned, err := list.HostnamesForDomain("eff.org")
returned, err := list.GetDomain("eff.org")
if err != nil {
t.Fatalf("Encountered %v", err)
}
if !reflect.DeepEqual(returned, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned)
if !reflect.DeepEqual(returned.MXs, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned.MXs)
}
}

Expand Down
27 changes: 27 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package util

import (
"reflect"
)

// ListsEqual checks that two lists have the same elements,
// regardless of order.
func ListsEqual(x []string, y []string) bool {
// Transform each list into a histogram
xMap := make(map[string]uint)
yMap := make(map[string]uint)
for _, element := range x {
if _, ok := xMap[element]; !ok {
xMap[element] = 0
}
xMap[element]++
}
for _, element := range y {
if _, ok := yMap[element]; !ok {
yMap[element] = 0
}
yMap[element]++
}
// Compare the histogram maps
return reflect.DeepEqual(xMap, yMap)
}
29 changes: 29 additions & 0 deletions util/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package util

import (
"testing"
)

func TestListsEqual(t *testing.T) {
testCases := []struct {
x []string
y []string
expected bool
}{
{[]string{}, []string{}, true},
{[]string{"a"}, []string{}, false},
{[]string{"a"}, []string{"a"}, true},
{[]string{"a", "a"}, []string{"a"}, false},
{[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true},
{[]string{"b", "a", "c"}, []string{"a", "b", "c"}, true},
{[]string{"b", "a", "b", "c"}, []string{"a", "b", "c"}, false},
{[]string{"b", "a", "b", "c"}, []string{"a", "b", "a", "c"}, false},
{[]string{"a", "a", "b", "c"}, []string{"a", "b", "a", "c"}, true},
}
for _, testCase := range testCases {
got := ListsEqual(testCase.x, testCase.y)
if got != testCase.expected {
t.Errorf("Compared %v and %v, expected %v, got %v", testCase.x, testCase.y, testCase.expected, got)
}
}
}
59 changes: 43 additions & 16 deletions validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/models"
"github.com/getsentry/raven-go"
)

Expand All @@ -14,7 +15,7 @@ import (
// expected hostnames).
type DomainPolicyStore interface {
DomainsToValidate() ([]string, error)
HostnamesForDomain(string) ([]string, error)
GetDomain(string) (models.Domain, error)
}

// Called with failure by defaault.
Expand All @@ -28,8 +29,10 @@ func reportToSentry(name string, domain string, result checker.DomainResult) {
result)
}

type checkPerformer func(string, []string) checker.DomainResult
type resultCallback func(string, string, checker.DomainResult)
type resultCallback func(string, models.Domain, checker.DomainResult)

// CheckPerformer defines a function that performs a security check on a domain.
type CheckPerformer func(models.Domain) checker.DomainResult

// Validator runs checks regularly against domain policies. This structure
// defines the configurations.
Expand All @@ -47,18 +50,42 @@ type Validator struct {
OnFailure resultCallback
// OnSuccess: optional. Called when a particular policy validation succeeds.
OnSuccess resultCallback
// checkPerformer: performs the check.
checkPerformer checkPerformer
// CheckPerformer: performs the check.
CheckPerformer CheckPerformer
}

// UpdatePolicy is a callback we can provide to GetDBCheck in order to perform a policy
// update if we notice a discrepancy between our view and the MTA-STS policy.
type UpdatePolicy func(models.Domain) error

// GetDBCheck returns a CheckPerformer that performs an MTASTS check and update if
// the policy is updated, or performs a regular security check if MTASTS is not supported.
func GetDBCheck(update UpdatePolicy) CheckPerformer {
c := checker.Checker{Cache: checker.MakeSimpleCache(time.Hour)}
return func(domain models.Domain) checker.DomainResult {
if domain.MTASTSMode == "on" {
result := c.CheckDomain(domain.Name, []string{})
if !domain.SamePolicy(result.MTASTSResult) {
if update(domain) != nil {
reportToSentry("Couldn't update policy in DB", domain.Name, result)
}
}
return result
}
return c.CheckDomain(domain.Name, domain.MXs)
}
}

func (v *Validator) checkPolicy(domain string, hostnames []string) checker.DomainResult {
if v.checkPerformer == nil {
func (v *Validator) checkPolicy(domain models.Domain) checker.DomainResult {
if v.CheckPerformer == nil {
c := checker.Checker{
Cache: checker.MakeSimpleCache(time.Hour),
}
v.checkPerformer = c.CheckDomain
v.CheckPerformer = func(domain models.Domain) checker.DomainResult {
return c.CheckDomain(domain.Name, domain.MXs)
}
}
return v.checkPerformer(domain, hostnames)
return v.CheckPerformer(domain)
}

func (v *Validator) interval() time.Duration {
Expand All @@ -68,14 +95,14 @@ func (v *Validator) interval() time.Duration {
return time.Hour * 24
}

func (v *Validator) policyFailed(name string, domain string, result checker.DomainResult) {
func (v *Validator) policyFailed(name string, domain models.Domain, result checker.DomainResult) {
if v.OnFailure != nil {
v.OnFailure(name, domain, result)
}
reportToSentry(name, domain, result)
reportToSentry(name, domain.Name, result)
}

func (v *Validator) policyPassed(name string, domain string, result checker.DomainResult) {
func (v *Validator) policyPassed(name string, domain models.Domain, result checker.DomainResult) {
if v.OnSuccess != nil {
v.OnSuccess(name, domain, result)
}
Expand All @@ -93,17 +120,17 @@ func (v *Validator) Run() {
continue
}
for _, domain := range domains {
hostnames, err := v.Store.HostnamesForDomain(domain)
domainData, err := v.Store.GetDomain(domain)
if err != nil {
log.Printf("[%s validator] Could not retrieve policy for domain %s: %v", v.Name, domain, err)
continue
}
result := v.checkPolicy(domain, hostnames)
result := v.checkPolicy(domainData)
if result.Status != 0 {
log.Printf("[%s validator] %s failed; sending report", v.Name, domain)
v.policyFailed(v.Name, domain, result)
v.policyFailed(v.Name, domainData, result)
} else {
v.policyPassed(v.Name, domain, result)
v.policyPassed(v.Name, domainData, result)
}
}
}
Expand Down
Loading