diff --git a/go.mod b/go.mod index 38cd071dce..c221e2951d 100644 --- a/go.mod +++ b/go.mod @@ -21,8 +21,6 @@ require ( github.com/c-robinson/iplib v1.0.8 github.com/cenkalti/backoff v2.2.1+incompatible github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b - github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 - github.com/emersion/go-smtp v0.24.0 github.com/google/go-cmp v0.7.0 github.com/google/go-containerregistry v0.21.2 github.com/google/subcommands v1.2.0 diff --git a/go.sum b/go.sum index 07b6bc21c8..ab5632cc52 100644 --- a/go.sum +++ b/go.sum @@ -297,10 +297,6 @@ github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0o github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= -github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 h1:oP4q0fw+fOSWn3DfFi4EXdT+B+gTtzx8GC9xsc26Znk= -github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= -github.com/emersion/go-smtp v0.24.0 h1:g6AfoF140mvW0vLNPD/LuCBLEAdlxOjIXqbIkJIS6Wk= -github.com/emersion/go-smtp v0.24.0/go.mod h1:ZtRRkbTyp2XTHCA+BmyTFTrj8xY4I+b4McvHxCU2gsQ= github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= diff --git a/reporter/email.go b/reporter/email.go index 7e8e18a78f..6bb11f2a1c 100644 --- a/reporter/email.go +++ b/reporter/email.go @@ -5,17 +5,83 @@ import ( "fmt" "net" "net/mail" + "net/smtp" "strings" "time" - sasl "github.com/emersion/go-sasl" - smtp "github.com/emersion/go-smtp" "golang.org/x/xerrors" "github.com/future-architect/vuls/config" "github.com/future-architect/vuls/models" ) +// plainAuth implements smtp.Auth for the PLAIN mechanism without +// stdlib's TLS enforcement, preserving behavioral parity with the +// previously used go-smtp library for TLSMode "None" configurations. +type plainAuth struct { + identity, username, password string +} + +func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if !server.TLS { + advertised := false + for _, mech := range server.Auth { + if strings.EqualFold(mech, "PLAIN") { + advertised = true + break + } + } + if !advertised { + return "", nil, xerrors.New("unencrypted connection: PLAIN auth requires TLS or explicit server advertisement") + } + } + resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password) + return "PLAIN", resp, nil +} + +func (a *plainAuth) Next(_ []byte, more bool) ([]byte, error) { + if more { + return nil, xerrors.New("unexpected server challenge") + } + return nil, nil +} + +// loginAuth implements smtp.Auth for the LOGIN mechanism. +type loginAuth struct { + username, password string +} + +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if !server.TLS { + advertised := false + for _, mech := range server.Auth { + if strings.EqualFold(mech, "LOGIN") { + advertised = true + break + } + } + if !advertised { + return "", nil, xerrors.New("unencrypted connection: LOGIN auth requires TLS or explicit server advertisement") + } + } + return "LOGIN", nil, nil +} + +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if !more { + return nil, nil + } + prompt := strings.TrimSpace(strings.ToLower(string(fromServer))) + switch { + case strings.Contains(prompt, "username"): + return []byte(a.username), nil + case strings.Contains(prompt, "password"): + return []byte(a.password), nil + default: + return nil, xerrors.Errorf("unexpected server challenge: %q", fromServer) + } +} + // EMailWriter send mail type EMailWriter struct { FormatOneEMail bool @@ -99,8 +165,37 @@ type emailSender struct { conf config.SMTPConf } +func (e *emailSender) dialTLS(addr string, tlsConfig *tls.Config) (*smtp.Client, error) { + conn, err := tls.Dial("tcp", addr, tlsConfig) + if err != nil { + return nil, xerrors.Errorf("Failed to create TLS connection to SMTP server: %w", err) + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + _ = conn.Close() + return nil, xerrors.Errorf("Failed to parse SMTP server address: %w", err) + } + c, err := smtp.NewClient(conn, host) + if err != nil { + _ = conn.Close() + return nil, xerrors.Errorf("Failed to create SMTP client over TLS: %w", err) + } + return c, nil +} + +func (e *emailSender) dialStartTLS(addr string, tlsConfig *tls.Config) (*smtp.Client, error) { + c, err := smtp.Dial(addr) + if err != nil { + return nil, xerrors.Errorf("Failed to create connection to SMTP server: %w", err) + } + if err := c.StartTLS(tlsConfig); err != nil { + _ = c.Close() + return nil, xerrors.Errorf("Failed to STARTTLS: %w", err) + } + return c, nil +} + func (e *emailSender) sendMail(smtpServerAddr, message string) (err error) { - var auth sasl.Client emailConf := e.conf tlsConfig := &tls.Config{ ServerName: emailConf.SMTPAddr, @@ -112,24 +207,20 @@ func (e *emailSender) sendMail(smtpServerAddr, message string) (err error) { case "": switch emailConf.SMTPPort { case "465": - c, err = smtp.DialTLS(smtpServerAddr, tlsConfig) + c, err = e.dialTLS(smtpServerAddr, tlsConfig) if err != nil { - return xerrors.Errorf("Failed to create TLS connection to SMTP server: %w", err) + return err } - defer c.Close() default: c, err = smtp.Dial(smtpServerAddr) if err != nil { return xerrors.Errorf("Failed to create connection to SMTP server: %w", err) } - defer c.Close() - if ok, _ := c.Extension("STARTTLS"); ok { - c, err = smtp.DialStartTLS(smtpServerAddr, tlsConfig) - if err != nil { - return xerrors.Errorf("Failed to create STARTTLS connection to SMTP server: %w", err) + if err := c.StartTLS(tlsConfig); err != nil { + _ = c.Close() + return xerrors.Errorf("Failed to STARTTLS: %w", err) } - defer c.Close() } } case "None": @@ -137,36 +228,36 @@ func (e *emailSender) sendMail(smtpServerAddr, message string) (err error) { if err != nil { return xerrors.Errorf("Failed to create connection to SMTP server: %w", err) } - defer c.Close() case "STARTTLS": - c, err = smtp.DialStartTLS(smtpServerAddr, tlsConfig) + c, err = e.dialStartTLS(smtpServerAddr, tlsConfig) if err != nil { - return xerrors.Errorf("Failed to create STARTTLS connection to SMTP server: %w", err) + return err } - defer c.Close() case "SMTPS": - c, err = smtp.DialTLS(smtpServerAddr, tlsConfig) + c, err = e.dialTLS(smtpServerAddr, tlsConfig) if err != nil { - return xerrors.Errorf("Failed to create TLS connection to SMTP server: %w", err) + return err } - defer c.Close() default: return xerrors.New(`invalid TLS mode. accepts: ["", "None", "STARTTLS", "SMTPS"]`) } + defer c.Close() if ok, param := c.Extension("AUTH"); ok { - authList := strings.Split(param, " ") - auth = e.newSaslClient(authList) - if err = c.Auth(auth); err != nil { - return xerrors.Errorf("Failed to authenticate: %w", err) + authList := strings.Fields(param) + auth := e.newAuth(authList) + if auth != nil { + if err = c.Auth(auth); err != nil { + return xerrors.Errorf("Failed to authenticate: %w", err) + } } } - if err = c.Mail(emailConf.From, nil); err != nil { + if err = c.Mail(emailConf.From); err != nil { return xerrors.Errorf("Failed to send Mail command: %w", err) } for _, to := range emailConf.To { - if err = c.Rcpt(to, nil); err != nil { + if err = c.Rcpt(to); err != nil { return xerrors.Errorf("Failed to send Rcpt command: %w", err) } } @@ -222,15 +313,13 @@ func NewEMailSender(cnf config.SMTPConf) EMailSender { return &emailSender{cnf} } -func (e *emailSender) newSaslClient(authList []string) sasl.Client { +func (e *emailSender) newAuth(authList []string) smtp.Auth { for _, v := range authList { - switch v { + switch strings.ToUpper(v) { case "PLAIN": - auth := sasl.NewPlainClient("", e.conf.User, e.conf.Password) - return auth + return &plainAuth{identity: "", username: e.conf.User, password: e.conf.Password} case "LOGIN": - auth := sasl.NewLoginClient(e.conf.User, e.conf.Password) - return auth + return &loginAuth{username: e.conf.User, password: e.conf.Password} } } return nil diff --git a/reporter/loginauth_test.go b/reporter/loginauth_test.go new file mode 100644 index 0000000000..124beea419 --- /dev/null +++ b/reporter/loginauth_test.go @@ -0,0 +1,170 @@ +package reporter + +import ( + "net/smtp" + "testing" +) + +func TestLoginAuthStart(t *testing.T) { + t.Parallel() + + t.Run("TLS connection succeeds", func(t *testing.T) { + t.Parallel() + auth := &loginAuth{username: "user@example.com", password: "secret"} + mech, resp, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: true}) + if err != nil { + t.Fatalf("Start() returned error: %v", err) + } + if mech != "LOGIN" { + t.Errorf("Start() mechanism = %q, want %q", mech, "LOGIN") + } + if resp != nil { + t.Errorf("Start() resp = %v, want nil", resp) + } + }) + + t.Run("non-TLS with LOGIN advertised succeeds", func(t *testing.T) { + t.Parallel() + auth := &loginAuth{username: "user@example.com", password: "secret"} + mech, _, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: false, Auth: []string{"LOGIN"}}) + if err != nil { + t.Fatalf("Start() returned error: %v", err) + } + if mech != "LOGIN" { + t.Errorf("Start() mechanism = %q, want %q", mech, "LOGIN") + } + }) + + t.Run("non-TLS without LOGIN advertised fails", func(t *testing.T) { + t.Parallel() + auth := &loginAuth{username: "user@example.com", password: "secret"} + _, _, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: false}) + if err == nil { + t.Fatal("Start() should return error for non-TLS connection without LOGIN advertised") + } + }) +} + +func TestLoginAuthNext(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fromServer string + more bool + want string + wantNil bool + }{ + {name: "username prompt", fromServer: "Username:", more: true, want: "user@example.com"}, + {name: "password prompt", fromServer: "Password:", more: true, want: "s3cret"}, + {name: "username lowercase", fromServer: "username:", more: true, want: "user@example.com"}, + {name: "password lowercase", fromServer: "password:", more: true, want: "s3cret"}, + {name: "not more", fromServer: "", more: false, wantNil: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + auth := &loginAuth{username: "user@example.com", password: "s3cret"} + got, err := auth.Next([]byte(tt.fromServer), tt.more) + if err != nil { + t.Fatalf("Next() returned error: %v", err) + } + if tt.wantNil { + if got != nil { + t.Errorf("Next() = %q, want nil", got) + } + return + } + if string(got) != tt.want { + t.Errorf("Next() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestPlainAuthStart(t *testing.T) { + t.Parallel() + + t.Run("TLS connection succeeds", func(t *testing.T) { + t.Parallel() + auth := &plainAuth{identity: "", username: "user@example.com", password: "secret"} + mech, resp, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: true}) + if err != nil { + t.Fatalf("Start() returned error: %v", err) + } + if mech != "PLAIN" { + t.Errorf("Start() mechanism = %q, want %q", mech, "PLAIN") + } + want := "\x00user@example.com\x00secret" + if string(resp) != want { + t.Errorf("Start() resp = %q, want %q", resp, want) + } + }) + + t.Run("non-TLS with PLAIN advertised succeeds", func(t *testing.T) { + t.Parallel() + auth := &plainAuth{identity: "", username: "user@example.com", password: "secret"} + mech, _, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: false, Auth: []string{"PLAIN"}}) + if err != nil { + t.Fatalf("Start() returned error: %v", err) + } + if mech != "PLAIN" { + t.Errorf("Start() mechanism = %q, want %q", mech, "PLAIN") + } + }) + + t.Run("non-TLS without PLAIN advertised fails", func(t *testing.T) { + t.Parallel() + auth := &plainAuth{identity: "", username: "user@example.com", password: "secret"} + _, _, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: false}) + if err == nil { + t.Fatal("Start() should return error for non-TLS connection without PLAIN advertised") + } + }) + + t.Run("identity is included in response", func(t *testing.T) { + t.Parallel() + auth := &plainAuth{identity: "admin", username: "user@example.com", password: "secret"} + _, resp, err := auth.Start(&smtp.ServerInfo{Name: "smtp.example.com", TLS: true}) + if err != nil { + t.Fatalf("Start() returned error: %v", err) + } + want := "admin\x00user@example.com\x00secret" + if string(resp) != want { + t.Errorf("Start() resp = %q, want %q", resp, want) + } + }) +} + +func TestPlainAuthNext(t *testing.T) { + t.Parallel() + auth := &plainAuth{identity: "", username: "user@example.com", password: "secret"} + + t.Run("no more data returns nil", func(t *testing.T) { + t.Parallel() + got, err := auth.Next(nil, false) + if err != nil { + t.Fatalf("Next() returned error: %v", err) + } + if got != nil { + t.Errorf("Next() = %q, want nil", got) + } + }) + + t.Run("more data returns error", func(t *testing.T) { + t.Parallel() + _, err := auth.Next([]byte("challenge"), true) + if err == nil { + t.Fatal("Next() should return error when server sends unexpected challenge") + } + }) +} + +func TestLoginAuthNextUnexpectedChallenge(t *testing.T) { + t.Parallel() + auth := &loginAuth{username: "user", password: "pass"} + _, err := auth.Next([]byte("Something unexpected:"), true) + if err == nil { + t.Fatal("Next() should return error for unexpected challenge") + } +}