diff --git a/pyproject.toml b/pyproject.toml index 7ff90d9..00320f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ 'prettytable', 'GitPython', 'packaging', + 'socket-sdk-python>=1.0.15,<2.0.0' ] readme = "README.md" description = "Socket Security CLI for CI/CD" @@ -45,4 +46,4 @@ include = [ ] [tool.setuptools.dynamic] -version = {attr = "socketsecurity.__version__"} \ No newline at end of file +version = {attr = "socketsecurity.__version__"} diff --git a/requirements.txt b/requirements.txt index 896774a..885c5b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +socket-sdk-python>=1.0.15,<2.0.0 requests>=2.32.0 mdutils~=1.6.0 prettytable diff --git a/socketsecurity/__init__.py b/socketsecurity/__init__.py index 48e740d..40a32c8 100644 --- a/socketsecurity/__init__.py +++ b/socketsecurity/__init__.py @@ -1,2 +1,2 @@ __author__ = 'socket.dev' -__version__ = '1.0.41' +__version__ = '1.0.42' diff --git a/socketsecurity/core/__init__.py b/socketsecurity/core/__init__.py index 4855b51..4b71606 100644 --- a/socketsecurity/core/__init__.py +++ b/socketsecurity/core/__init__.py @@ -1,36 +1,39 @@ +import base64 +import json import logging +import platform +import time +from glob import glob from pathlib import PurePath -from requests.exceptions import ReadTimeout -import requests from urllib.parse import urlencode -import base64 -import json -from socketsecurity.core.exceptions import ( - APIFailure, - APIKeyMissing, - APIAccessDenied, - APIInsufficientQuota, - APIResourceNotFound, - APICloudflareError, - RequestTimeoutExceeded -) + +import requests +from requests.exceptions import ReadTimeout +from socketdev import socketdev + from socketsecurity import __version__ -from socketsecurity.core.licenses import Licenses -from socketsecurity.core.issues import AllIssues from socketsecurity.core.classes import ( - Report, - Issue, - Package, Alert, + Diff, FullScan, FullScanParams, + Issue, + Package, + Purl, + Report, Repository, - Diff, - Purl ) -import platform -from glob import glob -import time +from socketsecurity.core.exceptions import ( + APIAccessDenied, + APICloudflareError, + APIFailure, + APIInsufficientQuota, + APIKeyMissing, + APIResourceNotFound, + RequestTimeoutExceeded, +) +from socketsecurity.core.issues import AllIssues +from socketsecurity.core.licenses import Licenses __all__ = [ "Core", @@ -55,6 +58,8 @@ log = logging.getLogger("socketdev") log.addHandler(logging.NullHandler()) +socket_sdk = None + socket_globs = { "spdx": { "spdx.json": { @@ -153,6 +158,15 @@ def encode_key(token: str) -> None: encoded_key = base64.b64encode(token.encode()).decode('ascii') +class SCMRequestError(Exception): + """Generic exception for SCM API request failures""" + def __init__(self, status_code: int, message: str, url: str): + self.status_code = status_code + self.message = message + self.url = url + super().__init__(f"SCM API request failed: {status_code} - {message} (URL: {url})") + + def do_request( path: str, headers: dict = None, @@ -160,34 +174,24 @@ def do_request( files: list = None, method: str = "GET", base_url: str = None, -) -> requests.request: +) -> requests.Response: """ - do_requests is the shared function for making HTTP calls - :param base_url: + Shared function for making HTTP calls to SCM providers (GitHub/GitLab) + :param base_url: Base URL for the SCM provider API :param path: Required path for the request - :param headers: Optional dictionary of headers. If not set will use a default set + :param headers: Optional dictionary of headers :param payload: Optional dictionary or string of the payload to pass :param files: Optional list of files to upload :param method: Optional method to use, defaults to GET - :return: + :return: Response object + :raises: SCMRequestError if the request fails """ + if base_url is None: + raise ValueError("base_url is required for SCM API calls") + + url = f"{base_url}/{path}" + verify = not allow_unverified_ssl - if base_url is not None: - url = f"{base_url}/{path}" - else: - if encoded_key is None or encoded_key == "": - raise APIKeyMissing - url = f"{api_url}/{path}" - - if headers is None: - headers = { - 'Authorization': f"Basic {encoded_key}", - 'User-Agent': f'SocketPythonCLI/{__version__}', - "accept": "application/json" - } - verify = True - if allow_unverified_ssl: - verify = False try: response = requests.request( method.upper(), @@ -198,43 +202,49 @@ def do_request( timeout=timeout, verify=verify ) - except ReadTimeout: - raise RequestTimeoutExceeded(f"Configured timeout {timeout} reached for request for path {url}") - output_headers = headers.copy() - output_headers['Authorization'] = "API_KEY_REDACTED" - output = { - "url": url, - "headers": output_headers, - "status_code": response.status_code, - "body": response.text, - "payload": payload, - "files": files, - "timeout": timeout - } - log.debug(output) - if response.status_code <= 399: - return response - elif response.status_code == 400: - raise APIFailure(output) - elif response.status_code == 401: - raise APIAccessDenied("Unauthorized") - elif response.status_code == 403: - raise APIInsufficientQuota("Insufficient max_quota for API method") - elif response.status_code == 404: - raise APIResourceNotFound(f"Path not found {path}") - elif response.status_code == 429: - raise APIInsufficientQuota("Insufficient quota for API route") - elif response.status_code == 524: - raise APICloudflareError(response.text) - else: - msg = { + + # Log request details (with redacted auth) + output_headers = headers.copy() if headers else {} + if 'Authorization' in output_headers: + output_headers['Authorization'] = "TOKEN_REDACTED" + + log.debug({ + "url": url, + "headers": output_headers, "status_code": response.status_code, - "UnexpectedError": "There was an unexpected error using the API", - "error": response.text, + "body": response.text, "payload": payload, - "url": url - } - raise APIFailure(msg) + "files": files, + "timeout": timeout + }) + + if response.status_code < 400: + return response + + # Try to get error message from response + try: + error_msg = response.json().get('message', response.text) + except (json.JSONDecodeError, AttributeError): + error_msg = response.text + + raise SCMRequestError( + status_code=response.status_code, + message=error_msg, + url=url + ) + + except ReadTimeout: + raise SCMRequestError( + status_code=408, + message=f"Request timed out after {timeout} seconds", + url=url + ) + except requests.RequestException as e: + raise SCMRequestError( + status_code=500, + message=str(e), + url=url + ) class Core: @@ -251,9 +261,10 @@ def __init__( enable_all_alerts: bool = False, allow_unverified: bool = False ): - global allow_unverified_ssl + global allow_unverified_ssl, socket_sdk allow_unverified_ssl = allow_unverified self.token = token + ":" + socket_sdk = socketdev(self.token, timeout=request_timeout) encode_key(self.token) self.socket_date_format = "%Y-%m-%dT%H:%M:%S.%fZ" self.base_api_url = base_api_url @@ -311,9 +322,7 @@ def get_org_id_slug() -> (str, str): Gets the Org ID and Org Slug for the API Token :return: """ - path = "organizations" - response = do_request(path) - data = response.json() + data = socket_sdk.org.get() organizations = data.get("organizations") new_org_id = None new_org_slug = None @@ -325,23 +334,30 @@ def get_org_id_slug() -> (str, str): @staticmethod def get_sbom_data(full_scan_id: str) -> list: - path = f"orgs/{org_slug}/full-scans/{full_scan_id}" - response = do_request(path) - results = [] + """ + Gets SBOM data for a full scan using the Socket SDK + :param full_scan_id: str - ID of the full scan to get SBOM data for + :return: list of SBOM artifacts + """ try: - data = response.json() - results = data.get("sbom_artifacts") or [] + result = socket_sdk.fullscans.stream(org_slug, full_scan_id) + if result.get("success", False): + # Remove metadata properties before returning artifacts + result.pop("success", None) + result.pop("status", None) + # The SDK returns a dict with the SBOM artifacts as values, so we need to convert it to a list + return list(result.values()) + else: + # TODO: In future ticket, throw appropriate error here instead of returning empty list + log.error(f"Failed to get SBOM data for scan {full_scan_id}") + log.error(f"Status: {result.get('status')}") + log.error(f"Message: {result.get('message')}") + return [] except Exception as error: - log.debug("Failed with old style full-scan API using new format") - log.debug(error) - data = response.text - data.strip('"') - data.strip() - for line in data.split("\n"): - if line != '"' and line != "" and line is not None: - item = json.loads(line) - results.append(item) - return results + # TODO: In future ticket, throw appropriate error here instead of returning empty list + log.error(f"Unexpected error getting SBOM data for scan {full_scan_id}") + log.error(error) + return [] @staticmethod def get_security_policy() -> dict: @@ -349,14 +365,7 @@ def get_security_policy() -> dict: Get the Security policy and determine the effective Org security policy :return: """ - path = "settings" - payload = [ - { - "organization": org_id - } - ] - response = do_request(path, payload=json.dumps(payload), method="POST") - data = response.json() + data = socket_sdk.settings.get(org_id) defaults = data.get("defaults") default_rules = defaults.get("issueRules") entries = data.get("entries") @@ -400,16 +409,15 @@ def get_manifest_files(package: Package, packages: dict) -> str: @staticmethod def create_sbom_output(diff: Diff) -> dict: - base_path = f"orgs/{org_slug}/export/cdx" - path = f"{base_path}/{diff.id}" - result = do_request(path=path) - try: - sbom = result.json() - except Exception as error: + result = socket_sdk.export.cdx_bom(org_slug, diff.id) + + if not result.get("success", False): log.error(f"Unable to get CycloneDX Output for {diff.id}") - log.error(error) - sbom = {} - return sbom + log.error(result.get("message", "No error message provided")) + return {} + + result.pop("success", None) + return result @staticmethod def match_supported_files(files: list) -> bool: @@ -462,6 +470,7 @@ def find_files(path: str) -> list: end_time = time.time() total_time = end_time - start_time log.info(f"Found {len(files)} in {total_time:.2f} seconds") + log.debug(f"Files found: {list(files)}") return list(files) @staticmethod @@ -473,38 +482,18 @@ def create_full_scan(files: list, params: FullScanParams, workspace: str) -> Ful :param workspace: str - Path of workspace :return: """ - send_files = [] create_full_start = time.time() log.debug("Creating new full scan") - for file in files: - if platform.system() == "Windows": - file = file.replace("\\", "/") - if "/" in file: - path, name = file.rsplit("/", 1) - else: - path = "." - name = file - full_path = f"{path}/{name}" - if full_path.startswith(workspace): - key = full_path[len(workspace):] - else: - key = full_path - key = key.lstrip("/") - key = key.lstrip("./") - payload = ( - key, - ( - name, - open(full_path, 'rb') - ) - ) - send_files.append(payload) - query_params = urlencode(params.__dict__) - full_uri = f"{full_scan_path}?{query_params}" - response = do_request(full_uri, method="POST", files=send_files) - results = response.json() + + # Convert params to dict and add org_slug + params_dict = params.__dict__.copy() + params_dict['org_slug'] = org_slug + + results = socket_sdk.fullscans.post(files=files, params=params_dict, workspace=workspace) + full_scan = FullScan(**results) full_scan.sbom_artifacts = Core.get_sbom_data(full_scan.id) + create_full_end = time.time() total_time = create_full_end - create_full_start log.debug(f"New Full Scan created in {total_time:.2f} seconds") @@ -527,9 +516,8 @@ def get_head_scan_for_repo(repo_slug: str): :param repo_slug: Str - Repo slug for the repository that is being diffed :return: """ - repo_path = f"{repository_path}/{repo_slug}" - response = do_request(repo_path) - results = response.json() + results = socket_sdk.repos.repo(org_slug, repo_name=repo_slug) + repository = Repository(**results) return repository.head_full_scan_id @@ -540,9 +528,7 @@ def get_full_scan(full_scan_id: str) -> FullScan: :param full_scan_id: str - ID of the full scan to pull :return: """ - full_scan_url = f"{full_scan_path}/{full_scan_id}" - response = do_request(full_scan_url) - results = response.json() + results = socket_sdk.fullscans.metadata(org_slug, full_scan_id) full_scan = FullScan(**results) full_scan.sbom_artifacts = Core.get_sbom_data(full_scan.id) return full_scan @@ -876,6 +862,18 @@ def create_sbom_dict(sbom: list) -> dict: log.debug(f"Orphaned top level package id {package_id} for packages {details}") else: packages[package_id].transitives = top_level_count[package_id] + + # Check for potential API truncation + top_levels_len = len(top_levels) + packages_len = len(packages) + difference = top_levels_len - packages_len + + if difference > 10 and difference > (packages_len * 0.5): + raise APIFailure( + f"Potential API truncation detected: Found {top_levels_len} top-level ancestors but only {packages_len} packages. " + f"This suggests the SBOM data may be incomplete." + ) + return packages @staticmethod diff --git a/socketsecurity/socketcli.py b/socketsecurity/socketcli.py index ac9dd48..a0da7d8 100644 --- a/socketsecurity/socketcli.py +++ b/socketsecurity/socketcli.py @@ -1,5 +1,6 @@ import argparse import json +import traceback import socketsecurity.core from socketsecurity.core import Core, __version__ @@ -169,9 +170,6 @@ type=float ) - - - def output_console_comments(diff_report: Diff, sbom_file_name: str = None) -> None: if diff_report.id != "NO_DIFF_RAN": console_security_comment = Messages.create_console_security_alert_table(diff_report) @@ -231,6 +229,8 @@ def cli(): except Exception as error: log.error("Unexpected error when running the cli") log.error(error) + log.error("Traceback:") + log.error(traceback.format_exc()) if not blocking_disabled: sys.exit(3) else: