Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions db/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ func (p *PolicyDB) PutOrUpdatePolicy(ps *models.PolicySubmission) error {

// DomainsToValidate [interface Validator] retrieves domains from the
// DB whose policies should be validated-- all Pending policies.
func (db SQLDatabase) DomainsToValidate() ([]string, error) {
func (p *PolicyDB) DomainsToValidate() ([]string, error) {
domains := []string{}
data, err := db.PendingPolicies.GetPolicies(false)
data, err := p.GetPolicies(true)
if err != nil {
return domains, err
}
Expand Down
23 changes: 0 additions & 23 deletions db/sqldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,29 +235,6 @@ func TestPutTokenTwice(t *testing.T) {
}
}

func TestDomainsToValidate(t *testing.T) {
database.ClearTables()
mtastsMap := map[string]bool{
"a": false, "b": true, "c": false, "d": true,
}
for domain, mtasts := range mtastsMap {
if mtasts {
database.Policies.PutOrUpdatePolicy(&models.PolicySubmission{Name: domain, MTASTS: true})
} else {
database.Policies.PutOrUpdatePolicy(&models.PolicySubmission{Name: domain})
}
}
result, err := database.DomainsToValidate()
if err != nil {
t.Fatalf("DomainsToValidate failed: %v\n", err)
}
for _, domain := range result {
if !mtastsMap[domain] {
t.Errorf("Did not expect %s to be returned", domain)
}
}
}

func TestHostnamesForDomain(t *testing.T) {
database.ClearTables()
database.PendingPolicies.PutOrUpdatePolicy(&models.PolicySubmission{Name: "x",
Expand Down
19 changes: 13 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,21 @@ func main() {
Emailer: emailConfig,
}
a.ParseTemplates()
if os.Getenv("VALIDATE_LIST") == "1" {
log.Println("[Starting list validator]")
go validator.ValidateRegularly("Live policy list", list, 24*time.Hour)
}
// if os.Getenv("VALIDATE_LIST") == "1" {
// log.Println("[Starting list validator]")
// go validator.ValidateRegularly("Live policy list", list, 24*time.Hour)
// }
if os.Getenv("VALIDATE_QUEUED") == "1" {
log.Println("[Starting queued validator]")
go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
v := validator.Validator{
Name: "testing and enforced domains",
Store: db.Policies,
Interval: 24 * time.Hour,
}
go v.Run()
// log.Println("[Starting queued validator]")
// go validator.ValidateRegularly("MTA-STS domains", db.Policies, 24*time.Hour)
}
// go validator.ValidateRegularly("MTA-STS domains", db.Policies, 24*time.Hour)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove leftover & commented code

go stats.UpdateRegularly(db, time.Hour)
ServePublicEndpoints(&a, &cfg)
}
59 changes: 34 additions & 25 deletions validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
"github.com/getsentry/raven-go"
)

Expand All @@ -14,7 +16,7 @@ import (
// expected hostnames).
type DomainPolicyStore interface {
DomainsToValidate() ([]string, error)
HostnamesForDomain(string) ([]string, error)
GetPolicy(string) (models.PolicySubmission, bool, error)
}

// Called with failure by defaault.
Expand All @@ -28,7 +30,7 @@ func reportToSentry(name string, domain string, result checker.DomainResult) {
result)
}

type checkPerformer func(string, []string) checker.DomainResult
type checkPerformer func(models.PolicySubmission) checker.DomainResult
type resultCallback func(string, string, checker.DomainResult)

// Validator runs checks regularly against domain policies. This structure
Expand All @@ -47,18 +49,37 @@ type Validator struct {
OnFailure resultCallback
// OnSuccess: optional. Called when a particular policy validation succeeds.
OnSuccess resultCallback
// checkPerformer: performs the check.
checkPerformer checkPerformer
// CheckPerformer: performs the check.
CheckPerformer checkPerformer
}

func (v *Validator) checkPolicy(domain string, hostnames []string) checker.DomainResult {
if v.checkPerformer == nil {
c := checker.Checker{
Cache: checker.MakeSimpleCache(time.Hour),
func resultMTASTSToPolicy(r *checker.MTASTSResult) *policy.TLSPolicy {
return &policy.TLSPolicy{Mode: r.Mode, MXs: r.MXs}
}

func getMTASTSUpdater(update func(*models.PolicySubmission) error) checkPerformer {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be used somewhere!

c := checker.Checker{Cache: checker.MakeSimpleCache(time.Hour)}
return func(p models.PolicySubmission) checker.DomainResult {
if p.MTASTS {
result := c.CheckDomain(p.Name, []string{})
if !p.Policy.Equals(resultMTASTSToPolicy(result.MTASTSResult)) {
if err := update(&p); err != nil {
reportToSentry(fmt.Sprintf("couldn't update policy in DB: %v", err), p.Name, result)
}
}
}
return c.CheckDomain(p.Name, p.Policy.MXs)
}
}

func (v *Validator) checkPolicy(p *models.PolicySubmission) checker.DomainResult {
if v.CheckPerformer == nil {
c := checker.Checker{Cache: checker.MakeSimpleCache(time.Hour)}
v.CheckPerformer = func(policy models.PolicySubmission) checker.DomainResult {
return c.CheckDomain(p.Name, p.Policy.MXs)
}
v.checkPerformer = c.CheckDomain
}
return v.checkPerformer(domain, hostnames)
return v.CheckPerformer(*p)
}

func (v *Validator) interval() time.Duration {
Expand Down Expand Up @@ -93,12 +114,12 @@ func (v *Validator) Run() {
continue
}
for _, domain := range domains {
hostnames, err := v.Store.HostnamesForDomain(domain)
if err != nil {
policy, ok, err := v.Store.GetPolicy(domain)
if err != nil || !ok {
log.Printf("[%s validator] Could not retrieve policy for domain %s: %v", v.Name, domain, err)
continue
}
result := v.checkPolicy(domain, hostnames)
result := v.checkPolicy(&policy)
if result.Status != 0 {
log.Printf("[%s validator] %s failed; sending report", v.Name, domain)
v.policyFailed(v.Name, domain, result)
Expand All @@ -108,15 +129,3 @@ func (v *Validator) Run() {
}
}
}

// ValidateRegularly regularly runs checker.CheckDomain against a Domain-
// Hostname map. Interval specifies the interval to wait between each run.
// Failures are reported to Sentry.
func ValidateRegularly(name string, store DomainPolicyStore, interval time.Duration) {
v := Validator{
Name: name,
Store: store,
Interval: interval,
}
v.Run()
}
16 changes: 9 additions & 7 deletions validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
)

type mockDomainPolicyStore struct {
Expand All @@ -19,21 +21,21 @@ func (m mockDomainPolicyStore) DomainsToValidate() ([]string, error) {
return domains, nil
}

func (m mockDomainPolicyStore) HostnamesForDomain(domain string) ([]string, error) {
return m.hostnames[domain], nil
func (m mockDomainPolicyStore) GetPolicy(domain string) (models.PolicySubmission, bool, error) {
return models.PolicySubmission{Name: domain, Policy: &policy.TLSPolicy{Mode: "testing", MXs: m.hostnames[domain]}}, true, nil
}

func noop(_ string, _ string, _ checker.DomainResult) {}

func TestRegularValidationValidates(t *testing.T) {
called := make(chan bool)
fakeChecker := func(domain string, hostnames []string) checker.DomainResult {
fakeChecker := func(_ models.PolicySubmission) checker.DomainResult {
called <- true
return checker.DomainResult{}
}
mock := mockDomainPolicyStore{
hostnames: map[string][]string{"a": []string{"hostname"}}}
v := Validator{Store: mock, Interval: 100 * time.Millisecond, checkPerformer: fakeChecker, OnFailure: noop}
v := Validator{Store: mock, Interval: 100 * time.Millisecond, CheckPerformer: fakeChecker, OnFailure: noop}
go v.Run()

select {
Expand All @@ -46,8 +48,8 @@ func TestRegularValidationValidates(t *testing.T) {

func TestRegularValidationReportsErrors(t *testing.T) {
reports := make(chan string)
fakeChecker := func(domain string, hostnames []string) checker.DomainResult {
if domain == "fail" || domain == "error" {
fakeChecker := func(p models.PolicySubmission) checker.DomainResult {
if p.Name == "fail" || p.Name == "error" {
return checker.DomainResult{Status: 5}
}
return checker.DomainResult{Status: 0}
Expand All @@ -64,7 +66,7 @@ func TestRegularValidationReportsErrors(t *testing.T) {
"fail": []string{"hostname"},
"error": []string{"hostname"},
"normal": []string{"hostname"}}}
v := Validator{Store: mock, Interval: 100 * time.Millisecond, checkPerformer: fakeChecker,
v := Validator{Store: mock, Interval: 100 * time.Millisecond, CheckPerformer: fakeChecker,
OnFailure: fakeReporter, OnSuccess: fakeSuccessReporter,
}
go v.Run()
Expand Down