Skip to content

Commit ac8cbdb

Browse files
Merge pull request #11 from Veridise/nikos/async-artifact-improvements
Add script to wait and download an artifact given its step-code and name.
2 parents ecebc64 + fd0c9e8 commit ac8cbdb

File tree

11 files changed

+228
-70
lines changed

11 files changed

+228
-70
lines changed

audithub_client/__main__.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
create_version_via_local_archive,
1313
)
1414
from .scripts.create_version_via_url import create_version_via_url # noqa
15+
from .scripts.download_artifact import download_artifact # noqa
1516
from .scripts.get_configuration import get_configuration # noqa
1617
from .scripts.get_my_organizations import get_my_organizations # noqa
1718
from .scripts.get_my_profile import get_my_profile # noqa
@@ -31,6 +32,18 @@
3132
)
3233

3334

35+
class LevelFormatter(logging.Formatter):
36+
def __init__(self, formats, default_fmt=None):
37+
super().__init__()
38+
self.formats = formats
39+
self.default_fmt = default_fmt or "%(levelname)s: %(message)s"
40+
41+
def format(self, record):
42+
fmt = self.formats.get(record.levelno, self.default_fmt)
43+
formatter = logging.Formatter(fmt)
44+
return formatter.format(record)
45+
46+
3447
@app.meta.default
3548
def meta(
3649
*tokens: Annotated[str, Parameter(show=False, allow_leading_hyphen=True)],
@@ -66,15 +79,31 @@ def meta(
6679
),
6780
] = "INFO",
6881
):
82+
# Leave all levels except DEBUG "simple", without outputting the module name.
83+
# At debug level, also show the module name that produces the log message.
84+
default_log_format = "%(asctime)s %(levelname)s %(message)s"
85+
debug_log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
86+
87+
# Setup logging as usual
6988
logging.basicConfig(
7089
level=log_level,
71-
# format="%(asctime)s.%(msecs)03d %(filename)s%(name)s %(levelname)s %(message)s", # cspell:disable-line
72-
format="%(asctime)s %(levelname)s %(message)s", # cspell:disable-line
90+
format=default_log_format,
7391
datefmt="%H:%M:%S",
7492
stream=sys.stderr,
7593
)
76-
httpx_logger = logging.getLogger("httpx")
77-
httpx_logger.setLevel(logging.WARNING)
94+
# Now modify the root logger
95+
root_logger = logging.getLogger()
96+
handler = root_logger.handlers[0] # basicConfig creates one handler by default
97+
handler.setFormatter(
98+
LevelFormatter(
99+
{logging.DEBUG: debug_log_format}, default_fmt=default_log_format
100+
)
101+
)
102+
103+
# For these modules, raise the log level to reduce noise
104+
for module in ["httpx", "httpcore"]:
105+
module_logger = logging.getLogger(module)
106+
module_logger.setLevel(logging.WARNING)
78107

79108
command, bound, _ignored = app.parse_args(tokens)
80109
# When this script runs with no args, help_print is automatically invoked

audithub_client/api/get_version_comments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
from dataclasses import dataclass
3-
from datetime import datetime, timedelta, timezone
3+
from datetime import datetime
44
from typing import Optional
55

66
from audithub_client.library.utils import get_dict_of_fields_except

audithub_client/library/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ def get_access_token(
3232
"scope": "openid profile",
3333
"grant_type": "client_credentials",
3434
}
35-
logger.debug("Payload is %s", payload)
35+
# logger.debug("Payload is %s", payload)
3636
response = post(token_url, data=payload, timeout=AUTHENTICATION_TIMEOUT)
3737
if response.status_code != 200:
3838
raise RuntimeError(
3939
f'Failed to get token for client {payload["client_id"]} status = {response.status_code} response ={response.text}'
4040
)
4141
token_data = response.json()
4242
end_time = time.perf_counter()
43-
logger.debug(json.dumps(token_data, indent=4))
43+
logger.debug(json.dumps(token_data))
4444
if token_time_listener:
4545
token_time_listener(end_time - begin_time)
4646
return token_data["access_token"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from pathlib import Path
2+
from typing import Tuple
3+
4+
from httpx import Client
5+
6+
from audithub_client.library.net_utils import download_file
7+
8+
from .auth import DEFAULT_REQUEST_TIMEOUT
9+
10+
11+
def download_from_url(
12+
url: str, output_file: Path, timeout=DEFAULT_REQUEST_TIMEOUT
13+
) -> Tuple[int, str]:
14+
"""This is a `curl -o`/`wget -O` equivalent"""
15+
with Client(timeout=timeout) as client:
16+
with client.stream("GET", url) as response:
17+
return download_file(response, output_file)

audithub_client/library/invocation_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,5 @@
6060
help="An optional task name for this task. If not specified, one will automatically be generated by AuditHub."
6161
),
6262
]
63+
64+
BooleanArg = Annotated[bool, Parameter(negative_bool=())]

audithub_client/library/json_dump.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66
from tabulate import tabulate
77

88
OutputType = Annotated[
9-
Literal["raw", "json", "json-pretty", "pprint", "list", "table"],
9+
Literal["raw", "json", "json-pretty", "pprint", "list", "table", "none"],
1010
Parameter(
1111
help="""\
12-
The output format. Options are: 'raw': Python print(), 'json': single-line JSON, 'json-pretty': multi-line JSON, 'pprint': Python pprint(), 'list': list element per line, 'table': tabular view.
12+
The output format. Options are: 'raw': Python print(), 'json': single-line JSON, 'json-pretty': multi-line JSON, 'pprint': Python pprint(), 'list': list element per line, 'table': tabular view, 'none': omit output completely.
1313
"""
1414
),
1515
]
1616

1717

1818
def dump_dict(document: Any, section: str | None = None, output: OutputType = "json"):
19-
19+
if output == "none":
20+
return
2021
if section == "sections":
2122
document = sorted(document.keys())
2223
elif section:
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import logging
2+
import sys
3+
from pathlib import Path
4+
from time import sleep
5+
from typing import Annotated
6+
7+
from cyclopts import Parameter, validators
8+
9+
from ..api.get_task_info import GetTaskInfoArgs, api_get_task_info
10+
from ..library.http_download import download_from_url
11+
from ..library.invocation_common import (
12+
AuditHubContextType,
13+
OrganizationIdType,
14+
TaskIdType,
15+
app,
16+
)
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
@app.command
22+
def download_artifact(
23+
*,
24+
organization_id: OrganizationIdType,
25+
task_id: TaskIdType,
26+
step_code: str,
27+
name: str,
28+
output_file: Path,
29+
timeout: Annotated[int, Parameter(validator=validators.Number(gt=0))] = 30,
30+
rpc_context: AuditHubContextType,
31+
):
32+
"""
33+
Download an artifact by name, potentially waiting for it to become available.
34+
35+
This is an improved version of get-task-artifact, that potentially waits for the artifact to become
36+
available and downloads it using the provided URL.
37+
38+
Note that you can use 'ah get-task-info' to obtain the list of step_codes
39+
(key 'steps') and produced artifacts (key 'artifacts').
40+
41+
Also note that, due to the asynchronous nature of AuditHub, artifacts may take a short amount of time
42+
until they become available, even when the task has finished.
43+
This is normal, and this command takes this into account.
44+
45+
Parameters
46+
----------
47+
step_code:
48+
The code of the workflow step that produced the artifact
49+
name:
50+
The name of the artifact.
51+
output_file:
52+
The local file name to store the output in.
53+
timeout:
54+
The number of seconds to potentially wait for the artifact to become available.
55+
"""
56+
try:
57+
found = False
58+
for attempt in range(1, timeout + 1):
59+
logger.debug("Starting attempt %d", attempt)
60+
task_info = api_get_task_info(
61+
rpc_context,
62+
GetTaskInfoArgs(organization_id=organization_id, task_id=task_id),
63+
)
64+
matched_artifacts = [
65+
e
66+
for e in task_info.get("artifacts", list())
67+
if e.get("step_code") == step_code and e.get("name") == name
68+
]
69+
if len(matched_artifacts) > 1:
70+
# This should be impossible
71+
raise RuntimeError(
72+
f"Multiple artifacts matched the condition, bailing: '{matched_artifacts}'"
73+
)
74+
if len(matched_artifacts) == 1:
75+
logger.debug("Artifact found at attempt %d, downloading...", attempt)
76+
bytes_written, hr_size = download_from_url(
77+
matched_artifacts[0]["presigned_url"], output_file
78+
)
79+
logger.info(
80+
f"Downloaded {bytes_written} bytes ({hr_size}) as {output_file}."
81+
)
82+
found = True
83+
break
84+
else:
85+
logger.info(
86+
"Artifact not found, waiting a sec at attempt %d..", attempt
87+
)
88+
sleep(1)
89+
90+
if found:
91+
logger.debug("Finished.")
92+
else:
93+
logger.error("Artifact not found (yet?).")
94+
sys.exit(1)
95+
96+
except Exception as ex:
97+
logger.error("Error %s", str(ex), exc_info=ex)

audithub_client/scripts/get_task_artifact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_task_artifact(
2323
rpc_context: AuditHubContextType,
2424
):
2525
"""
26-
DownGet logs of a task's step.
26+
Download a task's artifact, by its id.
2727
2828
Parameters
2929
----------

audithub_client/scripts/get_task_info.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ..api.get_task_info import GetTaskInfoArgs, api_get_task_info
66
from ..library.invocation_common import (
77
AuditHubContextType,
8+
BooleanArg,
89
OrganizationIdType,
910
TaskIdType,
1011
app,
@@ -21,7 +22,8 @@ def get_task_info(
2122
organization_id: OrganizationIdType,
2223
task_id: TaskIdType,
2324
output: OutputType = "json",
24-
verify: bool = False,
25+
verify: BooleanArg = False,
26+
check_completed: BooleanArg = False,
2527
rpc_context: AuditHubContextType,
2628
):
2729
"""
@@ -37,6 +39,9 @@ def get_task_info(
3739
If true, the findings counters are summed. If the sum if zero, the exit code is 0, otherwise it is 1.
3840
i.e., an exit code of 1 means there is at least one finding in one of the findings categories.
3941
This argument is independent of any output arguments.
42+
check_completed:
43+
If true, all steps that are "tools" are checked to see if they reported completed == false.
44+
In such a case, we produce a 'timeout error' with an exit code of 2
4045
"""
4146
try:
4247
rpc_input = GetTaskInfoArgs(organization_id=organization_id, task_id=task_id)
@@ -60,6 +65,18 @@ def get_task_info(
6065
exit_code = 1
6166
if exit_code == 0:
6267
print("No findings reported by this task, exiting with 0")
68+
if exit_code == 0 and check_completed:
69+
for step in task_info.get("steps", list()):
70+
completed = step.get("completed_without_timeout", None)
71+
if completed is not None:
72+
if not completed:
73+
exit_code = 2
74+
print(
75+
"Step",
76+
step.get("code"),
77+
"reported that it hit a timeout and could not complete its work",
78+
)
79+
6380
logger.debug("Finished.")
6481
sys.exit(exit_code)
6582
except Exception as ex:

0 commit comments

Comments
 (0)