diff --git a/__snapshots__/main_test.snap b/__snapshots__/main_test.snap index 9af4fbe6..76f66201 100755 --- a/__snapshots__/main_test.snap +++ b/__snapshots__/main_test.snap @@ -58,6 +58,44 @@ You must provide at least one path to either a lockfile or a directory containin --- +[TestRun_APIError/#00 - 1] +Loaded the following OSV databases: + api#http://: (using batches of 1000) + RubyGems (%% vulnerabilities, including withdrawn - last updated %%) + +/Gemfile.lock: found 1 package + Using config at /.osv-detector.yml (0 ignores) + Using db api#http://: (using batches of 1000) + Using db RubyGems (%% vulnerabilities, including withdrawn - last updated %%) + + no known vulnerabilities found + +--- + +[TestRun_APIError/#00 - 2] + an api error occurred while trying to check the packages listed in /Gemfile.lock: api returned unexpected status (POST http://:/querybatch 400) + +--- + +[TestRun_APIError/#01 - 1] +Loaded the following OSV databases: + api#http://: (using batches of 1000) + RubyGems (%% vulnerabilities, including withdrawn - last updated %%) + +/Gemfile.lock: found 1 package + Using config at /.osv-detector.yml (0 ignores) + Using db api#http://: (using batches of 1000) + Using db RubyGems (%% vulnerabilities, including withdrawn - last updated %%) + + no known vulnerabilities found + +--- + +[TestRun_APIError/#01 - 2] + an api error occurred while trying to check the packages listed in /Gemfile.lock: api response could not be parsed as json (POST http://:/querybatch): invalid character '<' looking for beginning of value + +--- + [TestRun_Configs/#00 - 1] Loaded the following OSV databases: diff --git a/internal/reporter/reporter.go b/internal/reporter/reporter.go index a3486455..f3a7adb9 100644 --- a/internal/reporter/reporter.go +++ b/internal/reporter/reporter.go @@ -16,6 +16,8 @@ type Reporter struct { stderr io.Writer outputAsJSON bool results []Result + + hasErrored bool } func New(stdout io.Writer, stderr io.Writer, outputAsJSON bool) *Reporter { @@ -27,9 +29,15 @@ func New(stdout io.Writer, stderr io.Writer, outputAsJSON bool) *Reporter { } } +func (r *Reporter) HasErrored() bool { + return r.hasErrored +} + // PrintErrorf writes the given message to stderr, regardless of if the reporter // is outputting as JSON or not func (r *Reporter) PrintErrorf(msg string, a ...any) { + r.hasErrored = true + fmt.Fprintf(r.stderr, msg, a...) } diff --git a/main.go b/main.go index faa1ea4e..21ef194e 100644 --- a/main.go +++ b/main.go @@ -733,6 +733,10 @@ This flag can be passed multiple times to ignore different vulnerabilities`) writeUpdatedConfigs(r, vulnsPerConfig) } + if r.HasErrored() && exitCode == 0 { + exitCode = 127 + } + return exitCode } diff --git a/main_normalize_test.go b/main_normalize_test.go index 00bf656f..2271a782 100644 --- a/main_normalize_test.go +++ b/main_normalize_test.go @@ -89,6 +89,16 @@ func normalizeDatabaseStats(t *testing.T, str string) string { return re.ReplaceAllString(str, "$1 (%% vulnerabilities, including withdrawn - last updated %%)") } +// normalizeLocalhostPort attempts to replace references to 127.0.0.1: +// with a placeholder, to ensure tests pass when using httptest.Server +func normalizeLocalhostPort(t *testing.T, str string) string { + t.Helper() + + re := cachedregexp.MustCompile(`127\.0\.0\.1:\d+`) + + return re.ReplaceAllString(str, ":") +} + // normalizeErrors attempts to replace error messages on alternative OSs with their // known linux equivalents, to ensure tests pass across different OSs func normalizeErrors(t *testing.T, str string) string { @@ -110,6 +120,7 @@ func normalizeSnapshot(t *testing.T, str string) string { normalizeTempDirectory, normalizeUserCacheDirectory, normalizeDatabaseStats, + normalizeLocalhostPort, normalizeErrors, } { str = normalizer(t, str) diff --git a/main_test.go b/main_test.go index f1c12a50..798434dd 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,8 @@ package main import ( "bytes" "fmt" + "net/http" + "net/http/httptest" "os" "path/filepath" "strings" @@ -883,3 +885,70 @@ func TestRun_EndToEnd(t *testing.T) { }) } } + +func TestRun_APIError(t *testing.T) { + t.Parallel() + + tests := []struct{ handler http.HandlerFunc }{ + { + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("{}")) + }, + }, + { + handler: func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("")) + }, + }, + } + for _, tt := range tests { + t.Run("", func(t *testing.T) { + t.Parallel() + + //nolint:usetesting // we need to customize the directory name to replace in snapshots + p, err := os.MkdirTemp("", "osv-detector-test-*") + if err != nil { + t.Fatalf("could not create test directory: %v", err) + } + + // ensure the test directory is removed when we're done testing + t.Cleanup(func() { + _ = os.RemoveAll(p) + }) + + // create a file for scanning + err = os.WriteFile(filepath.Join(p, "Gemfile.lock"), []byte(` +GEM + remote: https://rubygems.org/ + specs: + ast (2.4.2) +`), 0600) + + if err != nil { + t.Fatal(err) + } + + // setup a fake api server + ts := httptest.NewServer(tt.handler) + t.Cleanup(ts.Close) + + // create a config file setting up our api server + err = os.WriteFile(filepath.Join(p, ".osv-detector.yml"), []byte(` +extra-databases: + - url: `+ts.URL, + ), 0600) + + if err != nil { + t.Fatal(err) + } + + // run the cli in our tmp directory + testCli(t, cliTestCase{ + name: "", + args: []string{p}, + exit: 127, + }) + }) + } +}