diff --git a/db/policy.go b/db/policy.go index 660e2f2b..f8b6f128 100644 --- a/db/policy.go +++ b/db/policy.go @@ -97,9 +97,9 @@ func (p *PolicyDB) PutOrUpdatePolicy(ps *models.PolicySubmission) error { // DomainsToValidate [interface Validator] retrieves domains from the // DB whose policies should be validated-- all Pending policies. -func (db SQLDatabase) DomainsToValidate() ([]string, error) { +func (p *PolicyDB) DomainsToValidate() ([]string, error) { domains := []string{} - data, err := db.PendingPolicies.GetPolicies(false) + data, err := p.GetPolicies(true) if err != nil { return domains, err } diff --git a/db/sqldb_test.go b/db/sqldb_test.go index 7d669865..8e6ca095 100644 --- a/db/sqldb_test.go +++ b/db/sqldb_test.go @@ -235,29 +235,6 @@ func TestPutTokenTwice(t *testing.T) { } } -func TestDomainsToValidate(t *testing.T) { - database.ClearTables() - mtastsMap := map[string]bool{ - "a": false, "b": true, "c": false, "d": true, - } - for domain, mtasts := range mtastsMap { - if mtasts { - database.Policies.PutOrUpdatePolicy(&models.PolicySubmission{Name: domain, MTASTS: true}) - } else { - database.Policies.PutOrUpdatePolicy(&models.PolicySubmission{Name: domain}) - } - } - result, err := database.DomainsToValidate() - if err != nil { - t.Fatalf("DomainsToValidate failed: %v\n", err) - } - for _, domain := range result { - if !mtastsMap[domain] { - t.Errorf("Did not expect %s to be returned", domain) - } - } -} - func TestHostnamesForDomain(t *testing.T) { database.ClearTables() database.PendingPolicies.PutOrUpdatePolicy(&models.PolicySubmission{Name: "x", diff --git a/main.go b/main.go index 537a7f88..08e723f4 100644 --- a/main.go +++ b/main.go @@ -98,14 +98,21 @@ func main() { Emailer: emailConfig, } a.ParseTemplates() - if os.Getenv("VALIDATE_LIST") == "1" { - log.Println("[Starting list validator]") - go validator.ValidateRegularly("Live policy list", list, 24*time.Hour) - } + // if os.Getenv("VALIDATE_LIST") == "1" { + // log.Println("[Starting list validator]") + // go validator.ValidateRegularly("Live policy list", list, 24*time.Hour) + // } if os.Getenv("VALIDATE_QUEUED") == "1" { - log.Println("[Starting queued validator]") - go validator.ValidateRegularly("Testing domains", db, 24*time.Hour) + v := validator.Validator{ + Name: "testing and enforced domains", + Store: db.Policies, + Interval: 24 * time.Hour, + } + go v.Run() + // log.Println("[Starting queued validator]") + // go validator.ValidateRegularly("MTA-STS domains", db.Policies, 24*time.Hour) } + // go validator.ValidateRegularly("MTA-STS domains", db.Policies, 24*time.Hour) go stats.UpdateRegularly(db, time.Hour) ServePublicEndpoints(&a, &cfg) } diff --git a/validator/validator.go b/validator/validator.go index 87d0b606..ce270df0 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -6,6 +6,8 @@ import ( "time" "github.com/EFForg/starttls-backend/checker" + "github.com/EFForg/starttls-backend/models" + "github.com/EFForg/starttls-backend/policy" "github.com/getsentry/raven-go" ) @@ -14,7 +16,7 @@ import ( // expected hostnames). type DomainPolicyStore interface { DomainsToValidate() ([]string, error) - HostnamesForDomain(string) ([]string, error) + GetPolicy(string) (models.PolicySubmission, bool, error) } // Called with failure by defaault. @@ -28,7 +30,7 @@ func reportToSentry(name string, domain string, result checker.DomainResult) { result) } -type checkPerformer func(string, []string) checker.DomainResult +type checkPerformer func(models.PolicySubmission) checker.DomainResult type resultCallback func(string, string, checker.DomainResult) // Validator runs checks regularly against domain policies. This structure @@ -47,18 +49,37 @@ type Validator struct { OnFailure resultCallback // OnSuccess: optional. Called when a particular policy validation succeeds. OnSuccess resultCallback - // checkPerformer: performs the check. - checkPerformer checkPerformer + // CheckPerformer: performs the check. + CheckPerformer checkPerformer } -func (v *Validator) checkPolicy(domain string, hostnames []string) checker.DomainResult { - if v.checkPerformer == nil { - c := checker.Checker{ - Cache: checker.MakeSimpleCache(time.Hour), +func resultMTASTSToPolicy(r *checker.MTASTSResult) *policy.TLSPolicy { + return &policy.TLSPolicy{Mode: r.Mode, MXs: r.MXs} +} + +func getMTASTSUpdater(update func(*models.PolicySubmission) error) checkPerformer { + c := checker.Checker{Cache: checker.MakeSimpleCache(time.Hour)} + return func(p models.PolicySubmission) checker.DomainResult { + if p.MTASTS { + result := c.CheckDomain(p.Name, []string{}) + if !p.Policy.Equals(resultMTASTSToPolicy(result.MTASTSResult)) { + if err := update(&p); err != nil { + reportToSentry(fmt.Sprintf("couldn't update policy in DB: %v", err), p.Name, result) + } + } + } + return c.CheckDomain(p.Name, p.Policy.MXs) + } +} + +func (v *Validator) checkPolicy(p *models.PolicySubmission) checker.DomainResult { + if v.CheckPerformer == nil { + c := checker.Checker{Cache: checker.MakeSimpleCache(time.Hour)} + v.CheckPerformer = func(policy models.PolicySubmission) checker.DomainResult { + return c.CheckDomain(p.Name, p.Policy.MXs) } - v.checkPerformer = c.CheckDomain } - return v.checkPerformer(domain, hostnames) + return v.CheckPerformer(*p) } func (v *Validator) interval() time.Duration { @@ -93,12 +114,12 @@ func (v *Validator) Run() { continue } for _, domain := range domains { - hostnames, err := v.Store.HostnamesForDomain(domain) - if err != nil { + policy, ok, err := v.Store.GetPolicy(domain) + if err != nil || !ok { log.Printf("[%s validator] Could not retrieve policy for domain %s: %v", v.Name, domain, err) continue } - result := v.checkPolicy(domain, hostnames) + result := v.checkPolicy(&policy) if result.Status != 0 { log.Printf("[%s validator] %s failed; sending report", v.Name, domain) v.policyFailed(v.Name, domain, result) @@ -108,15 +129,3 @@ func (v *Validator) Run() { } } } - -// ValidateRegularly regularly runs checker.CheckDomain against a Domain- -// Hostname map. Interval specifies the interval to wait between each run. -// Failures are reported to Sentry. -func ValidateRegularly(name string, store DomainPolicyStore, interval time.Duration) { - v := Validator{ - Name: name, - Store: store, - Interval: interval, - } - v.Run() -} diff --git a/validator/validator_test.go b/validator/validator_test.go index eec55a9f..8da4084f 100644 --- a/validator/validator_test.go +++ b/validator/validator_test.go @@ -5,6 +5,8 @@ import ( "time" "github.com/EFForg/starttls-backend/checker" + "github.com/EFForg/starttls-backend/models" + "github.com/EFForg/starttls-backend/policy" ) type mockDomainPolicyStore struct { @@ -19,21 +21,21 @@ func (m mockDomainPolicyStore) DomainsToValidate() ([]string, error) { return domains, nil } -func (m mockDomainPolicyStore) HostnamesForDomain(domain string) ([]string, error) { - return m.hostnames[domain], nil +func (m mockDomainPolicyStore) GetPolicy(domain string) (models.PolicySubmission, bool, error) { + return models.PolicySubmission{Name: domain, Policy: &policy.TLSPolicy{Mode: "testing", MXs: m.hostnames[domain]}}, true, nil } func noop(_ string, _ string, _ checker.DomainResult) {} func TestRegularValidationValidates(t *testing.T) { called := make(chan bool) - fakeChecker := func(domain string, hostnames []string) checker.DomainResult { + fakeChecker := func(_ models.PolicySubmission) checker.DomainResult { called <- true return checker.DomainResult{} } mock := mockDomainPolicyStore{ hostnames: map[string][]string{"a": []string{"hostname"}}} - v := Validator{Store: mock, Interval: 100 * time.Millisecond, checkPerformer: fakeChecker, OnFailure: noop} + v := Validator{Store: mock, Interval: 100 * time.Millisecond, CheckPerformer: fakeChecker, OnFailure: noop} go v.Run() select { @@ -46,8 +48,8 @@ func TestRegularValidationValidates(t *testing.T) { func TestRegularValidationReportsErrors(t *testing.T) { reports := make(chan string) - fakeChecker := func(domain string, hostnames []string) checker.DomainResult { - if domain == "fail" || domain == "error" { + fakeChecker := func(p models.PolicySubmission) checker.DomainResult { + if p.Name == "fail" || p.Name == "error" { return checker.DomainResult{Status: 5} } return checker.DomainResult{Status: 0} @@ -64,7 +66,7 @@ func TestRegularValidationReportsErrors(t *testing.T) { "fail": []string{"hostname"}, "error": []string{"hostname"}, "normal": []string{"hostname"}}} - v := Validator{Store: mock, Interval: 100 * time.Millisecond, checkPerformer: fakeChecker, + v := Validator{Store: mock, Interval: 100 * time.Millisecond, CheckPerformer: fakeChecker, OnFailure: fakeReporter, OnSuccess: fakeSuccessReporter, } go v.Run()