Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion cvelib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def show_cve(
@click.option(
"--state",
type=click.Choice(CveApi.States.values(), case_sensitive=False),
help="Filter by reservation state.",
help="Filter by reservation/record state.",
)
@click.option(
"--reserved-lt", type=click.DateTime(), help="Filter by reservation time before timestamp."
Expand Down Expand Up @@ -877,6 +877,34 @@ def list_cves(
print_table(lines, highlight_header=not no_header)


@cli.command()
@click.option(
"--state",
type=click.Choice(
[str(CveApi.States.PUBLISHED), str(CveApi.States.REJECTED)], case_sensitive=False
),
help="Filter count by record state.",
)
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def count(ctx: click.Context, state: Optional[str], print_raw: bool) -> None:
"""Display the total count of CVE records, optionally filtered by state.

This retrieves the count of all CVE records across all organizations, and can be
filtered by state (PUBLISHED, REJECTED).
"""
cve_api = ctx.obj.cve_api
count_data = cve_api.count_cves(state=state)

if print_raw:
print_json_data(count_data)
return

state_text = f" in {state.upper()} state" if state else ""
click.echo(f"Total CVE records{state_text}: {count_data['totalCount']}")


@cli.command()
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
Expand Down
10 changes: 10 additions & 0 deletions cvelib/cve_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ def list_cves(
params["time_reserved.gt"] = reserved_gt.isoformat()
return self._get_paged("cve-id", page_data_attr="cve_ids", params=params)

def count_cves(self, state: Optional[str] = None) -> dict:
"""Return the count of CVE records, optionally filtered by state.

Only RESERVED and PUBLISHED CVE records can be counted.
"""
params = {}
if state:
params["state"] = state.upper()
return self._get("cve_count", params=params).json()

def quota(self) -> dict:
return self._get(f"org/{self.org}/id_quota").json()

Expand Down
26 changes: 26 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,32 @@ def test_cve_list():
)


def test_count():
count_response = {"totalCount": 123}
with mock.patch("cvelib.cli.CveApi.count_cves") as count_cves:
count_cves.return_value = count_response

# No state filter
runner = CliRunner()
result = runner.invoke(cli, DEFAULT_OPTS + ["count"])
assert result.exit_code == 0, result.output
assert result.output == "Total CVE records: 123\n"
count_cves.assert_called_with(state=None)

# With state filter
runner = CliRunner()
result = runner.invoke(cli, DEFAULT_OPTS + ["count", "--state", "published"])
assert result.exit_code == 0, result.output
assert result.output == "Total CVE records in PUBLISHED state: 123\n"
count_cves.assert_called_with(state="PUBLISHED")

# Raw output
runner = CliRunner()
result = runner.invoke(cli, DEFAULT_OPTS + ["count", "--raw"])
assert result.exit_code == 0, result.output
assert json.loads(result.output) == count_response


class TestCvePublish:
cve_id = "CVE-2001-0635"
cna_dict = {
Expand Down
14 changes: 14 additions & 0 deletions tests/test_cve_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,17 @@ def test_generator_not_overridden(self, sample_cve_json):
result = CveApi._add_generator(cve_json)
assert "x_generator" in result
assert result["x_generator"] == original_value


def test_count_cves():
with mock.patch("cvelib.cve_api.CveApi._get") as get_mock:
get_mock.return_value.json.return_value = {"totalCount": 42}
cve_api = CveApi(username="test_user", org="test_org", api_key="test_key")

count = cve_api.count_cves()
get_mock.assert_called_with("cve_count", params={})
assert count == {"totalCount": 42}

count = cve_api.count_cves(state="published")
get_mock.assert_called_with("cve_count", params={"state": "PUBLISHED"})
assert count == {"totalCount": 42}