diff --git a/socketdev/__init__.py b/socketdev/__init__.py index 7ddb98b..4cc3455 100644 --- a/socketdev/__init__.py +++ b/socketdev/__init__.py @@ -4,6 +4,7 @@ from socketdev.dependencies import Dependencies from socketdev.export import Export from socketdev.fullscans import FullScans +from socketdev.historical import Historical from socketdev.npm import NPM from socketdev.openapi import OpenAPI from socketdev.org import Orgs @@ -14,6 +15,7 @@ from socketdev.repositories import Repositories from socketdev.sbom import Sbom from socketdev.settings import Settings +from socketdev.triage import Triage from socketdev.utils import Utils, IntegrationType, INTEGRATION_TYPES from socketdev.version import __version__ @@ -23,7 +25,6 @@ __all__ = ["socketdev", "Utils", "IntegrationType", "INTEGRATION_TYPES"] - global encoded_key encoded_key: str @@ -32,6 +33,8 @@ log = logging.getLogger("socketdev") log.addHandler(logging.NullHandler()) +# TODO: Add debug flag to constructor to enable verbose error logging for API response parsing. + class socketdev: def __init__(self, token: str, timeout: int = 1200): @@ -41,18 +44,20 @@ def __init__(self, token: str, timeout: int = 1200): self.api.set_timeout(timeout) self.dependencies = Dependencies(self.api) + self.export = Export(self.api) + self.fullscans = FullScans(self.api) + self.historical = Historical(self.api) self.npm = NPM(self.api) self.openapi = OpenAPI(self.api) self.org = Orgs(self.api) + self.purl = Purl(self.api) self.quota = Quota(self.api) self.report = Report(self.api) - self.sbom = Sbom(self.api) - self.purl = Purl(self.api) - self.fullscans = FullScans(self.api) - self.export = Export(self.api) - self.repositories = Repositories(self.api) self.repos = Repos(self.api) + self.repositories = Repositories(self.api) + self.sbom = Sbom(self.api) self.settings = Settings(self.api) + self.triage = Triage(self.api) self.utils = Utils() @staticmethod diff --git a/socketdev/core/api.py b/socketdev/core/api.py index 220ada3..f2ecc9d 100644 --- a/socketdev/core/api.py +++ b/socketdev/core/api.py @@ -2,9 +2,16 @@ import requests from socketdev.core.classes import Response from socketdev.exceptions import ( - APIKeyMissing, APIFailure, APIAccessDenied, APIInsufficientQuota, - APIResourceNotFound, APITimeout, APIConnectionError, APIBadGateway, - APIInsufficientPermissions, APIOrganizationNotAllowed + APIKeyMissing, + APIFailure, + APIAccessDenied, + APIInsufficientQuota, + APIResourceNotFound, + APITimeout, + APIConnectionError, + APIBadGateway, + APIInsufficientPermissions, + APIOrganizationNotAllowed, ) from socketdev.version import __version__ from requests.exceptions import Timeout, ConnectionError @@ -24,7 +31,12 @@ def set_timeout(self, timeout: int): self.request_timeout = timeout def do_request( - self, path: str, headers: dict | None = None, payload: [dict, str] = None, files: list = None, method: str = "GET" + self, + path: str, + headers: dict | None = None, + payload: [dict, str] = None, + files: list = None, + method: str = "GET", ) -> Response: if self.encoded_key is None or self.encoded_key == "": raise APIKeyMissing @@ -36,33 +48,39 @@ def do_request( "accept": "application/json", } url = f"{self.api_url}/{path}" + + def format_headers(headers_dict): + return "\n".join(f"{k}: {v}" for k, v in headers_dict.items()) + try: start_time = time.time() response = requests.request( method.upper(), url, headers=headers, data=payload, files=files, timeout=self.request_timeout ) request_duration = time.time() - start_time - + + headers_str = f"\n\nHeaders:\n{format_headers(response.headers)}" if response.headers else "" + path_str = f"\nPath: {url}" + if response.status_code == 401: - raise APIAccessDenied("Unauthorized") + raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}") if response.status_code == 403: try: - error_message = response.json().get('error', {}).get('message', '') + error_message = response.json().get("error", {}).get("message", "") if "Insufficient permissions for API method" in error_message: - raise APIInsufficientPermissions(error_message) + raise APIInsufficientPermissions(f"{error_message}{path_str}{headers_str}") elif "Organization not allowed" in error_message: - raise APIOrganizationNotAllowed(error_message) + raise APIOrganizationNotAllowed(f"{error_message}{path_str}{headers_str}") elif "Insufficient max quota" in error_message: - raise APIInsufficientQuota(error_message) + raise APIInsufficientQuota(f"{error_message}{path_str}{headers_str}") else: - raise APIAccessDenied(error_message or "Access denied") + raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}") except ValueError: - # If JSON parsing fails - raise APIAccessDenied("Access denied") + raise APIAccessDenied(f"Access denied{path_str}{headers_str}") if response.status_code == 404: - raise APIResourceNotFound(f"Path not found {path}") + raise APIResourceNotFound(f"Path not found {path}{path_str}{headers_str}") if response.status_code == 429: - retry_after = response.headers.get('retry-after') + retry_after = response.headers.get("retry-after") if retry_after: try: seconds = int(retry_after) @@ -73,23 +91,34 @@ def do_request( time_msg = f" Retry after: {retry_after}" else: time_msg = "" - raise APIInsufficientQuota(f"Insufficient quota for API route.{time_msg}") + raise APIInsufficientQuota(f"Insufficient quota for API route.{time_msg}{path_str}{headers_str}") if response.status_code == 502: - raise APIBadGateway("Upstream server error") + raise APIBadGateway(f"Upstream server error{path_str}{headers_str}") if response.status_code >= 400: - raise APIFailure(f"Bad Request: HTTP {response.status_code}") - + raise APIFailure( + f"Bad Request: HTTP original_status_code:{response.status_code}{path_str}{headers_str}", + status_code=500, + ) + return response - + except Timeout: request_duration = time.time() - start_time raise APITimeout(f"Request timed out after {request_duration:.2f} seconds") except ConnectionError as error: request_duration = time.time() - start_time raise APIConnectionError(f"Connection error after {request_duration:.2f} seconds: {error}") - except (APIAccessDenied, APIInsufficientQuota, APIResourceNotFound, APIFailure, - APITimeout, APIConnectionError, APIBadGateway, APIInsufficientPermissions, - APIOrganizationNotAllowed): + except ( + APIAccessDenied, + APIInsufficientQuota, + APIResourceNotFound, + APIFailure, + APITimeout, + APIConnectionError, + APIBadGateway, + APIInsufficientPermissions, + APIOrganizationNotAllowed, + ): # Let all our custom exceptions propagate up unchanged raise except Exception as error: diff --git a/socketdev/dependencies/__init__.py b/socketdev/dependencies/__init__.py index 45ea8c5..201a9e0 100644 --- a/socketdev/dependencies/__init__.py +++ b/socketdev/dependencies/__init__.py @@ -1,8 +1,12 @@ import json from urllib.parse import urlencode - +import logging from socketdev.tools import load_files +log = logging.getLogger("socketdev") + +# TODO: Add types for responses. Not currently used in the CLI. + class Dependencies: def __init__(self, api): @@ -17,8 +21,8 @@ def post(self, files: list, params: dict) -> dict: result = response.json() else: result = {} - print(f"Error posting {files} to the Dependency API") - print(response.text) + log.error(f"Error posting {files} to the Dependency API") + log.error(response.text) return result def get( @@ -34,6 +38,6 @@ def get( result = response.json() else: result = {} - print("Unable to retrieve Dependencies") - print(response.text) + log.error("Unable to retrieve Dependencies") + log.error(response.text) return result diff --git a/socketdev/export/__init__.py b/socketdev/export/__init__.py index d56f886..19c04e1 100644 --- a/socketdev/export/__init__.py +++ b/socketdev/export/__init__.py @@ -1,6 +1,9 @@ from urllib.parse import urlencode from dataclasses import dataclass, asdict from typing import Optional +import logging + +log = logging.getLogger("socketdev") @dataclass @@ -23,40 +26,50 @@ class Export: def __init__(self, api): self.api = api - def cdx_bom(self, org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None) -> dict: + def cdx_bom( + self, org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None, use_types: bool = False + ) -> dict: """ Export a Socket SBOM as a CycloneDX SBOM :param org_slug: String - The slug of the organization :param id: String - The id of either a full scan or an sbom report :param query_params: Optional[ExportQueryParams] - Query parameters for filtering - :return: + :param use_types: Optional[bool] - Whether to return typed responses + :return: dict """ path = f"orgs/{org_slug}/export/cdx/{id}" if query_params: path += query_params.to_query_params() response = self.api.do_request(path=path) - try: - sbom = response.json() - sbom["success"] = True - except Exception as error: - sbom = {"success": False, "message": str(error)} - return sbom - - def spdx_bom(self, org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None) -> dict: + + if response.status_code == 200: + return response.json() + # TODO: Add typed response when types are defined + + log.error(f"Error exporting CDX BOM: {response.status_code}") + print(response.text) + return {} + + def spdx_bom( + self, org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None, use_types: bool = False + ) -> dict: """ Export a Socket SBOM as an SPDX SBOM :param org_slug: String - The slug of the organization :param id: String - The id of either a full scan or an sbom report :param query_params: Optional[ExportQueryParams] - Query parameters for filtering - :return: + :param use_types: Optional[bool] - Whether to return typed responses + :return: dict """ path = f"orgs/{org_slug}/export/spdx/{id}" if query_params: path += query_params.to_query_params() response = self.api.do_request(path=path) - try: - sbom = response.json() - sbom["success"] = True - except Exception as error: - sbom = {"success": False, "message": str(error)} - return sbom + + if response.status_code == 200: + return response.json() + # TODO: Add typed response when types are defined + + log.error(f"Error exporting SPDX BOM: {response.status_code}") + print(response.text) + return {} diff --git a/socketdev/fullscans/__init__.py b/socketdev/fullscans/__init__.py index 68caea9..01685d0 100644 --- a/socketdev/fullscans/__init__.py +++ b/socketdev/fullscans/__init__.py @@ -1,7 +1,7 @@ import json import logging from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass, asdict, field @@ -9,6 +9,7 @@ log = logging.getLogger("socketdev") + class SocketPURL_Type(str, Enum): UNKNOWN = "unknown" NPM = "npm" @@ -31,6 +32,7 @@ class SocketCategory(str, Enum): LICENSE = "license" MISCELLANEOUS = "miscellaneous" + class DiffType(str, Enum): ADDED = "added" REMOVED = "removed" @@ -38,6 +40,7 @@ class DiffType(str, Enum): REPLACED = "replaced" UPDATED = "updated" + @dataclass(kw_only=True) class SocketPURL: type: SocketPURL_Type @@ -47,8 +50,11 @@ class SocketPURL: subpath: Optional[str] = None version: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SocketPURL": @@ -58,25 +64,26 @@ def from_dict(cls, data: dict) -> "SocketPURL": namespace=data.get("namespace"), release=data.get("release"), subpath=data.get("subpath"), - version=data.get("version") + version=data.get("version"), ) + @dataclass class SocketManifestReference: file: str start: Optional[int] = None end: Optional[int] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SocketManifestReference": - return cls( - file=data["file"], - start=data.get("start"), - end=data.get("end") - ) + return cls(file=data["file"], start=data.get("start"), end=data.get("end")) + @dataclass class FullScanParams: @@ -93,8 +100,11 @@ class FullScanParams: set_as_pending_head: Optional[bool] = None tmp: Optional[bool] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "FullScanParams": @@ -111,9 +121,10 @@ def from_dict(cls, data: dict) -> "FullScanParams": integration_org_slug=data.get("integration_org_slug"), make_default_branch=data.get("make_default_branch"), set_as_pending_head=data.get("set_as_pending_head"), - tmp=data.get("tmp") + tmp=data.get("tmp"), ) + @dataclass class FullScanMetadata: id: str @@ -123,15 +134,18 @@ class FullScanMetadata: repository_id: str branch: str html_report_url: str - repo: Optional[str] = None + repo: Optional[str] = None organization_slug: Optional[str] = None committers: Optional[List[str]] = None commit_message: Optional[str] = None commit_hash: Optional[str] = None pull_request: Optional[int] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "FullScanMetadata": @@ -148,9 +162,10 @@ def from_dict(cls, data: dict) -> "FullScanMetadata": committers=data.get("committers"), commit_message=data.get("commit_message"), commit_hash=data.get("commit_hash"), - pull_request=data.get("pull_request") + pull_request=data.get("pull_request"), ) + @dataclass class CreateFullScanResponse: success: bool @@ -158,8 +173,11 @@ class CreateFullScanResponse: data: Optional[FullScanMetadata] = None message: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "CreateFullScanResponse": @@ -167,9 +185,10 @@ def from_dict(cls, data: dict) -> "CreateFullScanResponse": success=data["success"], status=data["status"], message=data.get("message"), - data=FullScanMetadata.from_dict(data.get("data")) if data.get("data") else None + data=FullScanMetadata.from_dict(data.get("data")) if data.get("data") else None, ) + @dataclass class GetFullScanMetadataResponse: success: bool @@ -177,8 +196,11 @@ class GetFullScanMetadataResponse: data: Optional[FullScanMetadata] = None message: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "GetFullScanMetadataResponse": @@ -186,19 +208,23 @@ def from_dict(cls, data: dict) -> "GetFullScanMetadataResponse": success=data["success"], status=data["status"], message=data.get("message"), - data=FullScanMetadata.from_dict(data.get("data")) if data.get("data") else None + data=FullScanMetadata.from_dict(data.get("data")) if data.get("data") else None, ) + @dataclass(kw_only=True) class SocketArtifactLink: topLevelAncestors: List[str] - direct: bool = False + direct: bool = False artifact: Optional[Dict] = None dependencies: Optional[List[str]] = None manifestFiles: Optional[List[SocketManifestReference]] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SocketArtifactLink": @@ -209,7 +235,7 @@ def from_dict(cls, data: dict) -> "SocketArtifactLink": direct=direct_val if isinstance(direct_val, bool) else direct_val.lower() == "true", artifact=data.get("artifact"), dependencies=data.get("dependencies"), - manifestFiles=[SocketManifestReference.from_dict(m) for m in manifest_files] if manifest_files else None + manifestFiles=[SocketManifestReference.from_dict(m) for m in manifest_files] if manifest_files else None, ) @@ -222,8 +248,11 @@ class SocketScore: license: float overall: float - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SocketScore": @@ -233,9 +262,10 @@ def from_dict(cls, data: dict) -> "SocketScore": maintenance=data["maintenance"], vulnerability=data["vulnerability"], license=data["license"], - overall=data["overall"] + overall=data["overall"], ) + @dataclass class SecurityCapabilities: env: bool @@ -245,8 +275,11 @@ class SecurityCapabilities: shell: bool unsafe: bool - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SecurityCapabilities": @@ -256,9 +289,10 @@ def from_dict(cls, data: dict) -> "SecurityCapabilities": fs=data["fs"], net=data["net"], shell=data["shell"], - unsafe=data["unsafe"] + unsafe=data["unsafe"], ) + @dataclass class Alert: key: str @@ -270,8 +304,11 @@ class Alert: action: str actionPolicyIndex: int - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "Alert": @@ -283,23 +320,25 @@ def from_dict(cls, data: dict) -> "Alert": end=data["end"], props=data["props"], action=data["action"], - actionPolicyIndex=data["actionPolicyIndex"] + actionPolicyIndex=data["actionPolicyIndex"], ) + @dataclass class LicenseMatch: licenseId: str licenseExceptionId: str - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "LicenseMatch": - return cls( - licenseId=data["licenseId"], - licenseExceptionId=data["licenseExceptionId"] - ) + return cls(licenseId=data["licenseId"], licenseExceptionId=data["licenseExceptionId"]) + @dataclass class LicenseDetail: @@ -312,8 +351,11 @@ class LicenseDetail: provenance: str spdxDisj: List[List[LicenseMatch]] - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "LicenseDetail": @@ -325,10 +367,10 @@ def from_dict(cls, data: dict) -> "LicenseDetail": match_strength=data["match_strength"], filehash=data["filehash"], provenance=data["provenance"], - spdxDisj=[[LicenseMatch.from_dict(match) for match in group] - for group in data["spdxDisj"]] + spdxDisj=[[LicenseMatch.from_dict(match) for match in group] for group in data["spdxDisj"]], ) + @dataclass class AttributionData: purl: str @@ -336,8 +378,11 @@ class AttributionData: foundInFilepath: Optional[str] = None spdxExpr: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "AttributionData": @@ -345,24 +390,28 @@ def from_dict(cls, data: dict) -> "AttributionData": purl=data["purl"], foundAuthors=data["foundAuthors"], foundInFilepath=data.get("foundInFilepath"), - spdxExpr=data.get("spdxExpr") + spdxExpr=data.get("spdxExpr"), ) + @dataclass class LicenseAttribution: attribText: str attribData: List[AttributionData] - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "LicenseAttribution": return cls( - attribText=data["attribText"], - attribData=[AttributionData.from_dict(item) for item in data["attribData"]] + attribText=data["attribText"], attribData=[AttributionData.from_dict(item) for item in data["attribData"]] ) + @dataclass class SocketAlert: key: str @@ -376,8 +425,11 @@ class SocketAlert: action: Optional[str] = None actionPolicyIndex: Optional[int] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SocketAlert": @@ -391,7 +443,7 @@ def from_dict(cls, data: dict) -> "SocketAlert": end=data.get("end"), props=data.get("props"), action=data.get("action"), - actionPolicyIndex=data.get("actionPolicyIndex") + actionPolicyIndex=data.get("actionPolicyIndex"), ) @@ -401,11 +453,11 @@ class DiffArtifact: id: str type: str name: str - score: SocketScore version: str - alerts: List[SocketAlert] licenseDetails: List[LicenseDetail] + score: Optional[SocketScore] = None author: List[str] = field(default_factory=list) + alerts: List[SocketAlert] = field(default_factory=list) license: Optional[str] = None files: Optional[str] = None capabilities: Optional[SecurityCapabilities] = None @@ -421,21 +473,28 @@ class DiffArtifact: error: Optional[str] = None licenseAttrib: Optional[List[LicenseAttribution]] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "DiffArtifact": base_data = data.get("base") head_data = data.get("head") + + score_data = data.get("score") or data.get("scores") + score = SocketScore.from_dict(score_data) if score_data else None + return cls( diffType=DiffType(data["diffType"]), id=data["id"], type=data["type"], name=data["name"], - score=SocketScore.from_dict(data["score"]), + score=score, version=data["version"], - alerts=[SocketAlert.from_dict(alert) for alert in data["alerts"]], + alerts=[SocketAlert.from_dict(alert) for alert in data.get("alerts", [])], licenseDetails=[LicenseDetail.from_dict(detail) for detail in data["licenseDetails"]], files=data.get("files"), license=data.get("license"), @@ -451,9 +510,12 @@ def from_dict(cls, data: dict) -> "DiffArtifact": author=data.get("author", []), state=data.get("state"), error=data.get("error"), - licenseAttrib=[LicenseAttribution.from_dict(attrib) for attrib in data["licenseAttrib"]] if data.get("licenseAttrib") else None + licenseAttrib=[LicenseAttribution.from_dict(attrib) for attrib in data["licenseAttrib"]] + if data.get("licenseAttrib") + else None, ) + @dataclass class DiffArtifacts: added: List[DiffArtifact] @@ -462,8 +524,11 @@ class DiffArtifacts: replaced: List[DiffArtifact] updated: List[DiffArtifact] - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "DiffArtifacts": @@ -472,9 +537,10 @@ def from_dict(cls, data: dict) -> "DiffArtifacts": removed=[DiffArtifact.from_dict(a) for a in data["removed"]], unchanged=[DiffArtifact.from_dict(a) for a in data["unchanged"]], replaced=[DiffArtifact.from_dict(a) for a in data["replaced"]], - updated=[DiffArtifact.from_dict(a) for a in data["updated"]] + updated=[DiffArtifact.from_dict(a) for a in data["updated"]], ) + @dataclass class CommitInfo: repository_id: str @@ -486,8 +552,11 @@ class CommitInfo: commit_hash: Optional[str] = None pull_request: Optional[int] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "CommitInfo": @@ -499,9 +568,10 @@ def from_dict(cls, data: dict) -> "CommitInfo": committers=data["committers"], commit_message=data.get("commit_message"), commit_hash=data.get("commit_hash"), - pull_request=data.get("pull_request") + pull_request=data.get("pull_request"), ) + @dataclass class FullScanDiffReport: before: CommitInfo @@ -510,8 +580,11 @@ class FullScanDiffReport: artifacts: DiffArtifacts directDependenciesChanged: bool = False - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "FullScanDiffReport": @@ -520,9 +593,10 @@ def from_dict(cls, data: dict) -> "FullScanDiffReport": after=CommitInfo.from_dict(data["after"]), directDependenciesChanged=data.get("directDependenciesChanged", False), diff_report_url=data["diff_report_url"], - artifacts=DiffArtifacts.from_dict(data["artifacts"]) + artifacts=DiffArtifacts.from_dict(data["artifacts"]), ) + @dataclass class StreamDiffResponse: success: bool @@ -530,8 +604,11 @@ class StreamDiffResponse: data: Optional[FullScanDiffReport] = None message: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "StreamDiffResponse": @@ -539,9 +616,10 @@ def from_dict(cls, data: dict) -> "StreamDiffResponse": success=data["success"], status=data["status"], message=data.get("message"), - data=FullScanDiffReport.from_dict(data.get("data")) if data.get("data") else None + data=FullScanDiffReport.from_dict(data.get("data")) if data.get("data") else None, ) + @dataclass(kw_only=True) class SocketArtifact(SocketPURL, SocketArtifactLink): id: str @@ -554,19 +632,22 @@ class SocketArtifact(SocketPURL, SocketArtifactLink): licenseDetails: Optional[List[LicenseDetail]] = field(default_factory=list) size: Optional[int] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SocketArtifact": purl_data = {k: data.get(k) for k in SocketPURL.__dataclass_fields__} link_data = {k: data.get(k) for k in SocketArtifactLink.__dataclass_fields__} - + alerts = data.get("alerts") license_attrib = data.get("licenseAttrib") license_details = data.get("licenseDetails") score = data.get("score") - + return cls( **purl_data, **link_data, @@ -578,9 +659,10 @@ def from_dict(cls, data: dict) -> "SocketArtifact": licenseAttrib=[LicenseAttribution.from_dict(la) for la in license_attrib] if license_attrib else None, licenseDetails=[LicenseDetail.from_dict(ld) for ld in license_details] if license_details else None, score=SocketScore.from_dict(score) if score else None, - size=data.get("size") + size=data.get("size"), ) + @dataclass class FullScanStreamResponse: success: bool @@ -588,8 +670,11 @@ class FullScanStreamResponse: artifacts: Optional[Dict[str, SocketArtifact]] = None message: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "FullScanStreamResponse": @@ -597,12 +682,12 @@ def from_dict(cls, data: dict) -> "FullScanStreamResponse": success=data["success"], status=data["status"], message=data.get("message"), - artifacts={ - k: SocketArtifact.from_dict(v) - for k, v in data["artifacts"].items() - } if data.get("artifacts") else None + artifacts={k: SocketArtifact.from_dict(v) for k, v in data["artifacts"].items()} + if data.get("artifacts") + else None, ) + class FullScans: def __init__(self, api): self.api = api @@ -613,7 +698,6 @@ def create_params_string(self, params: dict) -> str: for name, value in params.items(): if value: if name == "committers" and isinstance(value, list): - for committer in value: param_str += f"&{name}={committer}" else: @@ -623,60 +707,50 @@ def create_params_string(self, params: dict) -> str: return param_str - def get(self, org_slug: str, params: dict) -> GetFullScanMetadataResponse: + def get(self, org_slug: str, params: dict, use_types: bool = False) -> Union[dict, GetFullScanMetadataResponse]: params_arg = self.create_params_string(params) Utils.validate_integration_type(params.get("integration_type", "")) path = "orgs/" + org_slug + "/full-scans" + str(params_arg) - headers = None - payload = None - - response = self.api.do_request(path=path, headers=headers, payload=payload) + response = self.api.do_request(path=path) if response.status_code == 200: result = response.json() - return GetFullScanMetadataResponse.from_dict({ - "success": True, - "status": 200, - "data": result - }) + if use_types: + return GetFullScanMetadataResponse.from_dict({"success": True, "status": 200, "data": result}) + return result error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error getting full scan metadata: {response.status_code}, message: {error_message}") - return GetFullScanMetadataResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) - - def post(self, files: list, params: FullScanParams) -> CreateFullScanResponse: - + if use_types: + return GetFullScanMetadataResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} + + def post(self, files: list, params: FullScanParams, use_types: bool = False) -> Union[dict, CreateFullScanResponse]: org_slug = str(params.org_slug) params_dict = params.to_dict() params_dict.pop("org_slug") - params_arg = self.create_params_string(params_dict) + params_arg = self.create_params_string(params_dict) path = "orgs/" + org_slug + "/full-scans" + str(params_arg) response = self.api.do_request(path=path, method="POST", files=files) - + if response.status_code == 201: result = response.json() - return CreateFullScanResponse.from_dict({ - "success": True, - "status": 201, - "data": result - }) - - log.error(f"Error posting {files} to the Fullscans API") - error_message = response.json().get("error", {}).get("message", "Unknown error") - log.error(error_message) + if use_types: + return CreateFullScanResponse.from_dict({"success": True, "status": 201, "data": result}) + return result - return CreateFullScanResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error posting {files} to the Fullscans API: {response.status_code}, message: {error_message}") + if use_types: + return CreateFullScanResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} def delete(self, org_slug: str, full_scan_id: str) -> dict: path = "orgs/" + org_slug + "/full-scans/" + full_scan_id @@ -691,85 +765,82 @@ def delete(self, org_slug: str, full_scan_id: str) -> dict: log.error(f"Error deleting full scan: {response.status_code}, message: {error_message}") return {} - def stream_diff(self, org_slug: str, before: str, after: str) -> StreamDiffResponse: + def stream_diff( + self, org_slug: str, before: str, after: str, use_types: bool = False + ) -> Union[dict, StreamDiffResponse]: path = f"orgs/{org_slug}/full-scans/diff?before={before}&after={after}" response = self.api.do_request(path=path, method="GET") if response.status_code == 200: - return StreamDiffResponse.from_dict({ - "success": True, - "status": 200, - "data": response.json() - }) + result = response.json() + if use_types: + return StreamDiffResponse.from_dict({"success": True, "status": 200, "data": result}) + return result error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error streaming diff: {response.status_code}, message: {error_message}") - return StreamDiffResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) + if use_types: + return StreamDiffResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} - def stream(self, org_slug: str, full_scan_id: str) -> FullScanStreamResponse: + def stream(self, org_slug: str, full_scan_id: str, use_types: bool = False) -> Union[dict, FullScanStreamResponse]: path = "orgs/" + org_slug + "/full-scans/" + full_scan_id response = self.api.do_request(path=path, method="GET") - + if response.status_code == 200: try: stream_str = [] artifacts = {} result = response.text - result.strip('"') - result.strip() + result = result.strip('"').strip() for line in result.split("\n"): if line != '"' and line != "" and line is not None: item = json.loads(line) stream_str.append(item) for val in stream_str: - artifacts[val["id"]] = val + artifacts[val["id"]] = val + + if use_types: + return FullScanStreamResponse.from_dict({"success": True, "status": 200, "artifacts": artifacts}) + return artifacts - return FullScanStreamResponse.from_dict({ - "success": True, - "status": 200, - "artifacts": artifacts - }) except Exception as e: error_message = f"Error parsing stream response: {str(e)}" log.error(error_message) - return FullScanStreamResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) + if use_types: + return FullScanStreamResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error streaming full scan: {response.status_code}, message: {error_message}") - return FullScanStreamResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) + if use_types: + return FullScanStreamResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} - def metadata(self, org_slug: str, full_scan_id: str) -> GetFullScanMetadataResponse: + def metadata( + self, org_slug: str, full_scan_id: str, use_types: bool = False + ) -> Union[dict, GetFullScanMetadataResponse]: path = "orgs/" + org_slug + "/full-scans/" + full_scan_id + "/metadata" response = self.api.do_request(path=path, method="GET") if response.status_code == 200: - return GetFullScanMetadataResponse.from_dict({ - "success": True, - "status": 200, - "data": response.json() - }) + result = response.json() + if use_types: + return GetFullScanMetadataResponse.from_dict({"success": True, "status": 200, "data": result}) + return result error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error getting metadata: {response.status_code}, message: {error_message}") - return GetFullScanMetadataResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) - - - + if use_types: + return GetFullScanMetadataResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} diff --git a/socketdev/historical/__init__.py b/socketdev/historical/__init__.py new file mode 100644 index 0000000..223d90c --- /dev/null +++ b/socketdev/historical/__init__.py @@ -0,0 +1,47 @@ +import logging +from urllib.parse import urlencode + +log = logging.getLogger("socketdev") + + +class Historical: + def __init__(self, api): + self.api = api + + def list(self, org_slug: str, query_params: dict = None) -> dict: + """Get historical alerts list for an organization. + + Args: + org_slug: Organization slug + query_params: Optional dictionary of query parameters + """ + path = f"orgs/{org_slug}/alerts/historical" + if query_params: + path += "?" + urlencode(query_params) + + response = self.api.do_request(path=path) + if response.status_code == 200: + return response.json() + + log.error(f"Error getting historical alerts: {response.status_code}") + log.error(response.text) + return {} + + def trend(self, org_slug: str, query_params: dict = None) -> dict: + """Get historical alerts trend data for an organization. + + Args: + org_slug: Organization slug + query_params: Optional dictionary of query parameters + """ + path = f"orgs/{org_slug}/alerts/historical/trend" + if query_params: + path += "?" + urlencode(query_params) + + response = self.api.do_request(path=path) + if response.status_code == 200: + return response.json() + + log.error(f"Error getting historical trend: {response.status_code}") + log.error(response.text) + return {} diff --git a/socketdev/npm/__init__.py b/socketdev/npm/__init__.py index a54a6ba..55abd90 100644 --- a/socketdev/npm/__init__.py +++ b/socketdev/npm/__init__.py @@ -1,4 +1,8 @@ +import logging +log = logging.getLogger("socketdev") + +# TODO: Add response type classes for NPM endpoints class NPM: @@ -8,15 +12,17 @@ def __init__(self, api): def issues(self, package: str, version: str) -> list: path = f"npm/{package}/{version}/issues" response = self.api.do_request(path=path) - issues = [] if response.status_code == 200: - issues = response.json() - return issues + return response.json() + log.error(f"Error getting npm issues: {response.status_code}") + print(response.text) + return [] def score(self, package: str, version: str) -> list: path = f"npm/{package}/{version}/score" response = self.api.do_request(path=path) - issues = [] if response.status_code == 200: - issues = response.json() - return issues + return response.json() + log.error(f"Error getting npm score: {response.status_code}") + print(response.text) + return [] diff --git a/socketdev/openapi/__init__.py b/socketdev/openapi/__init__.py index b3df1da..96f7fa9 100644 --- a/socketdev/openapi/__init__.py +++ b/socketdev/openapi/__init__.py @@ -1,4 +1,8 @@ +import logging +log = logging.getLogger("socketdev") + +# TODO: Add response type classes for OpenAPI endpoints class OpenAPI: @@ -9,7 +13,7 @@ def get(self) -> dict: path = "openapi" response = self.api.do_request(path=path) if response.status_code == 200: - openapi = response.json() - else: - openapi = {} - return openapi + return response.json() + log.error(f"Error getting OpenAPI spec: {response.status_code}") + print(response.text) + return {} diff --git a/socketdev/org/__init__.py b/socketdev/org/__init__.py index 12d906c..d59aa0a 100644 --- a/socketdev/org/__init__.py +++ b/socketdev/org/__init__.py @@ -1,4 +1,8 @@ from typing import TypedDict, Dict +import logging + +log = logging.getLogger("socketdev") + class Organization(TypedDict): id: str @@ -7,18 +11,24 @@ class Organization(TypedDict): plan: str slug: str + class OrganizationsResponse(TypedDict): organizations: Dict[str, Organization] # Add other fields from the response if needed + class Orgs: def __init__(self, api): self.api = api - def get(self) -> OrganizationsResponse: + def get(self, use_types: bool = False) -> OrganizationsResponse: path = "organizations" response = self.api.do_request(path=path) if response.status_code == 200: - return response.json() # Return the full response - else: - return {"organizations": {}} # Return an empty structure \ No newline at end of file + result = response.json() + if use_types: + return OrganizationsResponse(result) + return result + log.error(f"Error getting organizations: {response.status_code}") + print(response.text) + return {"organizations": {}} diff --git a/socketdev/purl/__init__.py b/socketdev/purl/__init__.py index 6842e11..248f2ff 100644 --- a/socketdev/purl/__init__.py +++ b/socketdev/purl/__init__.py @@ -5,7 +5,7 @@ class Purl: def __init__(self, api): self.api = api - def post(self, license: str = "true", components: list = []) -> dict: + def post(self, license: str = "true", components: list = []) -> list: path = "purl?" + "license=" + license components = {"components": components} components = json.dumps(components) @@ -13,19 +13,17 @@ def post(self, license: str = "true", components: list = []) -> dict: response = self.api.do_request(path=path, payload=components, method="POST") if response.status_code == 200: purl = [] - purl_dict = {} result = response.text - result.strip('"') - result.strip() + result = result.strip('"').strip() for line in result.split("\n"): - if line != '"' and line != "" and line is not None: - item = json.loads(line) - purl.append(item) - for val in purl: - purl_dict[val["id"]] = val - else: - purl_dict = {} - print(f"Error posting {components} to the Purl API") - print(response.text) + if line and line != '"': + try: + item = json.loads(line) + purl.append(item) + except json.JSONDecodeError: + continue + return purl - return purl_dict + log.error(f"Error posting {components} to the Purl API: {response.status_code}") + print(response.text) + return [] diff --git a/socketdev/quota/__init__.py b/socketdev/quota/__init__.py index 1494888..bd3269f 100644 --- a/socketdev/quota/__init__.py +++ b/socketdev/quota/__init__.py @@ -1,3 +1,9 @@ +import logging + +log = logging.getLogger("socketdev") + +# TODO: Add response type classes for Quota endpoints + class Quota: def __init__(self, api): @@ -7,7 +13,7 @@ def get(self) -> dict: path = "quota" response = self.api.do_request(path=path) if response.status_code == 200: - quota = response.json() - else: - quota = {} - return quota + return response.json() + log.error(f"Error getting quota: {response.status_code}") + print(response.text) + return {} diff --git a/socketdev/report/__init__.py b/socketdev/report/__init__.py index 5483b2a..f92a621 100644 --- a/socketdev/report/__init__.py +++ b/socketdev/report/__init__.py @@ -1,6 +1,10 @@ - +import logging from datetime import datetime, timedelta, timezone +log = logging.getLogger("socketdev") + +# TODO: Add response type classes for Report endpoints + class Report: def __init__(self, api): @@ -21,37 +25,37 @@ def list(self, from_time: int = None) -> dict: path += f"?from={from_time}" response = self.api.do_request(path=path) if response.status_code == 200: - reports = response.json() - else: - reports = {} - return reports + return response.json() + log.error(f"Error listing reports: {response.status_code}") + print(response.text) + return {} def delete(self, report_id: str) -> bool: path = f"report/delete/{report_id}" response = self.api.do_request(path=path, method="DELETE") if response.status_code == 200: - deleted = True - else: - deleted = False - return deleted + return True + log.error(f"Error deleting report: {response.status_code}") + print(response.text) + return False def view(self, report_id) -> dict: path = f"report/view/{report_id}" response = self.api.do_request(path=path) if response.status_code == 200: - report = response.json() - else: - report = {} - return report + return response.json() + log.error(f"Error viewing report: {response.status_code}") + print(response.text) + return {} def supported(self) -> dict: path = "report/supported" response = self.api.do_request(path=path) if response.status_code == 200: - report = response.json() - else: - report = {} - return report + return response.json() + log.error(f"Error getting supported reports: {response.status_code}") + print(response.text) + return {} def create(self, files: list) -> dict: open_files = [] @@ -62,7 +66,7 @@ def create(self, files: list) -> dict: payload = {} response = self.api.do_request(path=path, method="PUT", files=open_files, payload=payload) if response.status_code == 200: - reports = response.json() - else: - reports = {} - return reports + return response.json() + log.error(f"Error creating report: {response.status_code}") + print(response.text) + return {} diff --git a/socketdev/repos/__init__.py b/socketdev/repos/__init__.py index 8328abd..f1028d9 100644 --- a/socketdev/repos/__init__.py +++ b/socketdev/repos/__init__.py @@ -1,10 +1,11 @@ import json import logging -from typing import List, Optional +from typing import Optional, Union from dataclasses import dataclass, asdict log = logging.getLogger("socketdev") + @dataclass class RepositoryInfo: id: str @@ -19,8 +20,11 @@ class RepositoryInfo: default_branch: str slug: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "RepositoryInfo": @@ -35,9 +39,10 @@ def from_dict(cls, data: dict) -> "RepositoryInfo": visibility=data["visibility"], archived=data["archived"], default_branch=data["default_branch"], - slug=data.get("slug") + slug=data.get("slug"), ) + @dataclass class GetRepoResponse: success: bool @@ -45,8 +50,11 @@ class GetRepoResponse: data: Optional[RepositoryInfo] = None message: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "GetRepoResponse": @@ -54,70 +62,65 @@ def from_dict(cls, data: dict) -> "GetRepoResponse": success=data["success"], status=data["status"], message=data.get("message"), - data=RepositoryInfo.from_dict(data.get("data")) if data.get("data") else None + data=RepositoryInfo.from_dict(data.get("data")) if data.get("data") else None, ) + class Repos: def __init__(self, api): self.api = api - def get(self, org_slug: str, **kwargs) -> dict[str, List[RepositoryInfo]]: - query_params = {} - if kwargs: - for key, val in kwargs.items(): - query_params[key] = val - if len(query_params) == 0: - return {} - + def get(self, org_slug: str, **kwargs) -> dict[str, list[dict] | int]: + query_params = kwargs path = "orgs/" + org_slug + "/repos" - if query_params is not None: + + if query_params: # Only add query string if we have parameters path += "?" for param in query_params: value = query_params[param] path += f"{param}={value}&" path = path.rstrip("&") - + response = self.api.do_request(path=path) - + if response.status_code == 200: raw_result = response.json() - result = { - key: [RepositoryInfo.from_dict(repo) for repo in repos] - for key, repos in raw_result.items() - } - return result + per_page = int(query_params.get("per_page", 30)) + + # TEMPORARY: Handle pagination edge case where API returns nextPage=1 even when no more results exist + if raw_result["nextPage"] != 0 and len(raw_result["results"]) < per_page: + raw_result["nextPage"] = 0 + + return raw_result error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error getting repositories: {response.status_code}, message: {error_message}") return {} - def repo(self, org_slug: str, repo_name: str) -> GetRepoResponse: + def repo(self, org_slug: str, repo_name: str, use_types: bool = False) -> Union[dict, GetRepoResponse]: path = f"orgs/{org_slug}/repos/{repo_name}" response = self.api.do_request(path=path) - + if response.status_code == 200: result = response.json() - return GetRepoResponse.from_dict({ - "success": True, - "status": 200, - "data": result - }) - + if use_types: + return GetRepoResponse.from_dict({"success": True, "status": 200, "data": result}) + return result + error_message = response.json().get("error", {}).get("message", "Unknown error") - log.error(f"Failed to get repository: {response.status_code}, message: {error_message}") - return GetRepoResponse.from_dict({ - "success": False, - "status": response.status_code, - "message": error_message - }) + print(f"Failed to get repository: {response.status_code}, message: {error_message}") + if use_types: + return GetRepoResponse.from_dict( + {"success": False, "status": response.status_code, "message": error_message} + ) + return {} def delete(self, org_slug: str, name: str) -> dict: path = f"orgs/{org_slug}/repos/{name}" response = self.api.do_request(path=path, method="DELETE") - + if response.status_code == 200: - result = response.json() - return result + return response.json() error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error deleting repository: {response.status_code}, message: {error_message}") @@ -130,14 +133,13 @@ def post(self, org_slug: str, **kwargs) -> dict: params[key] = val if len(params) == 0: return {} - + path = "orgs/" + org_slug + "/repos" payload = json.dumps(params) response = self.api.do_request(path=path, method="POST", payload=payload) - + if response.status_code == 201: - result = response.json() - return result + return response.json() error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error creating repository: {response.status_code}, message: {error_message}") @@ -150,14 +152,13 @@ def update(self, org_slug: str, repo_name: str, **kwargs) -> dict: params[key] = val if len(params) == 0: return {} - + path = f"orgs/{org_slug}/repos/{repo_name}" payload = json.dumps(params) response = self.api.do_request(path=path, method="POST", payload=payload) - + if response.status_code == 200: - result = response.json() - return result + return response.json() error_message = response.json().get("error", {}).get("message", "Unknown error") log.error(f"Error updating repository: {response.status_code}, message: {error_message}") diff --git a/socketdev/repositories/__init__.py b/socketdev/repositories/__init__.py index f1eaaa6..4390313 100644 --- a/socketdev/repositories/__init__.py +++ b/socketdev/repositories/__init__.py @@ -1,4 +1,7 @@ -from typing import TypedDict +from typing import TypedDict, Union +import logging + +log = logging.getLogger("socketdev") class Repo(TypedDict): @@ -14,11 +17,15 @@ class Repositories: def __init__(self, api): self.api = api - def list(self) -> dict: + def list(self, use_types: bool = False) -> Union[dict, list[Repo]]: path = "repos" response = self.api.do_request(path=path) if response.status_code == 200: - repos = response.json() - else: - repos = {} - return repos + result = response.json() + if use_types: + return [Repo(repo) for repo in result] + return result + + log.error(f"Error listing repositories: {response.status_code}") + print(response.text) + return [] diff --git a/socketdev/sbom/__init__.py b/socketdev/sbom/__init__.py index 4752e54..67aa15e 100644 --- a/socketdev/sbom/__init__.py +++ b/socketdev/sbom/__init__.py @@ -1,11 +1,20 @@ import json from socketdev.core.classes import Package +import logging + +log = logging.getLogger("socketdev") + +# TODO: Add response type classes for SBOM endpoints class Sbom: def __init__(self, api): self.api = api + # NOTE: This method's NDJSON handling is inconsistent with other methods in the SDK. + # While other methods return arrays for NDJSON responses, this returns a dictionary. + # This inconsistency is preserved to maintain backward compatibility with clients + # who have been using this method since its introduction 9 months ago. def view(self, report_id: str) -> dict[str, dict]: path = f"sbom/view/{report_id}" response = self.api.do_request(path=path) @@ -22,6 +31,8 @@ def view(self, report_id: str) -> dict[str, dict]: for val in sbom: sbom_dict[val["id"]] = val else: + log.error(f"Error viewing SBOM: {response.status_code}") + print(response.text) sbom_dict = {} return sbom_dict diff --git a/socketdev/settings/__init__.py b/socketdev/settings/__init__.py index 733c7e9..416454c 100644 --- a/socketdev/settings/__init__.py +++ b/socketdev/settings/__init__.py @@ -1,28 +1,33 @@ import logging from enum import Enum -from typing import Dict, Optional +from typing import Dict, Optional, Union from dataclasses import dataclass, asdict + log = logging.getLogger("socketdev") + class SecurityAction(str, Enum): - DEFER = 'defer' - ERROR = 'error' - WARN = 'warn' - MONITOR = 'monitor' - IGNORE = 'ignore' + DEFER = "defer" + ERROR = "error" + WARN = "warn" + MONITOR = "monitor" + IGNORE = "ignore" + @dataclass class SecurityPolicyRule: action: SecurityAction - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "SecurityPolicyRule": - return cls( - action=SecurityAction(data["action"]) - ) + return cls(action=SecurityAction(data["action"])) + @dataclass class OrgSecurityPolicyResponse: @@ -31,21 +36,24 @@ class OrgSecurityPolicyResponse: securityPolicyRules: Optional[Dict[str, SecurityPolicyRule]] = None message: Optional[str] = None - def __getitem__(self, key): return getattr(self, key) - def to_dict(self): return asdict(self) + def __getitem__(self, key): + return getattr(self, key) + + def to_dict(self): + return asdict(self) @classmethod def from_dict(cls, data: dict) -> "OrgSecurityPolicyResponse": return cls( - securityPolicyRules={ - k: SecurityPolicyRule.from_dict(v) - for k, v in data["securityPolicyRules"].items() - } if data.get("securityPolicyRules") else None, + securityPolicyRules={k: SecurityPolicyRule.from_dict(v) for k, v in data["securityPolicyRules"].items()} + if data.get("securityPolicyRules") + else None, success=data["success"], status=data["status"], - message=data.get("message") + message=data.get("message"), ) + class Settings: def __init__(self, api): self.api = api @@ -63,10 +71,11 @@ def create_params_string(self, params: dict) -> str: param_str += f"&{name}={value}" param_str = "?" + param_str.lstrip("&") - return param_str - def get(self, org_slug: str, custom_rules_only: bool = False) -> OrgSecurityPolicyResponse: + def get( + self, org_slug: str, custom_rules_only: bool = False, use_types: bool = False + ) -> Union[dict, OrgSecurityPolicyResponse]: path = f"orgs/{org_slug}/settings/security-policy" params = {"custom_rules_only": custom_rules_only} params_args = self.create_params_string(params) if custom_rules_only else "" @@ -75,17 +84,89 @@ def get(self, org_slug: str, custom_rules_only: bool = False) -> OrgSecurityPoli if response.status_code == 200: rules = response.json() - return OrgSecurityPolicyResponse.from_dict({ - "securityPolicyRules": rules.get("securityPolicyRules", {}), - "success": True, - "status": 200 - }) - else: - error_message = response.json().get("error", {}).get("message", "Unknown error") - log.error(f"Failed to get security policy: {response.status_code}, message: {error_message}") - return OrgSecurityPolicyResponse.from_dict({ - "securityPolicyRules": {}, - "success": False, - "status": response.status_code, - "message": error_message - }) + if use_types: + return OrgSecurityPolicyResponse.from_dict( + {"securityPolicyRules": rules.get("securityPolicyRules", {}), "success": True, "status": 200} + ) + return rules + + error_message = response.json().get("error", {}).get("message", "Unknown error") + print(f"Failed to get security policy: {response.status_code}, message: {error_message}") + if use_types: + return OrgSecurityPolicyResponse.from_dict( + {"securityPolicyRules": {}, "success": False, "status": response.status_code, "message": error_message} + ) + return {} + + def integration_events(self, org_slug: str, integration_id: str) -> dict: + """Get integration events for a specific integration. + + Args: + org_slug: Organization slug + integration_id: Integration ID + """ + path = f"orgs/{org_slug}/settings/integrations/{integration_id}" + response = self.api.do_request(path=path) + + if response.status_code == 200: + return response.json() + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error getting integration events: {response.status_code}, message: {error_message}") + return {} + + def get_license_policy(self, org_slug: str) -> dict: + """Get license policy settings for an organization. + + Args: + org_slug: Organization slug + """ + path = f"orgs/{org_slug}/settings/license-policy" + response = self.api.do_request(path=path) + + if response.status_code == 200: + return response.json() + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error getting license policy: {response.status_code}, message: {error_message}") + return {} + + def update_security_policy(self, org_slug: str, body: dict, custom_rules_only: bool = False) -> dict: + """Update security policy settings for an organization. + + Args: + org_slug: Organization slug + body: Security policy configuration to update + custom_rules_only: Optional flag to update only custom rules + """ + path = f"orgs/{org_slug}/settings/security-policy" + if custom_rules_only: + path += "?custom_rules_only=true" + + response = self.api.do_request(path=path, method="POST", payload=body) + + if response.status_code == 200: + return response.json() + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error updating security policy: {response.status_code}, message: {error_message}") + return {} + + def update_license_policy(self, org_slug: str, body: dict, merge_update: bool = False) -> dict: + """Update license policy settings for an organization. + + Args: + org_slug: Organization slug + body: License policy configuration to update + merge_update: Optional flag to merge updates instead of replacing (defaults to False) + """ + path = f"orgs/{org_slug}/settings/license-policy?merge_update={str(merge_update).lower()}" + + response = self.api.do_request(path=path, method="POST", payload=body) + + if response.status_code == 200: + return response.json() + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error updating license policy: {response.status_code}, message: {error_message}") + return {} diff --git a/socketdev/triage/__init__.py b/socketdev/triage/__init__.py new file mode 100644 index 0000000..b3a9d96 --- /dev/null +++ b/socketdev/triage/__init__.py @@ -0,0 +1,47 @@ +import logging +from urllib.parse import urlencode + +log = logging.getLogger("socketdev") + + +class Triage: + def __init__(self, api): + self.api = api + + def list_alert_triage(self, org_slug: str, query_params: dict = None) -> dict: + """Get list of triaged alerts for an organization. + + Args: + org_slug: Organization slug + query_params: Optional dictionary of query parameters + """ + path = f"orgs/{org_slug}/triage/alerts" + if query_params: + path += "?" + urlencode(query_params) + + response = self.api.do_request(path=path) + + if response.status_code == 200: + return response.json() + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error getting alert triage list: {response.status_code}, message: {error_message}") + return {} + + def update_alert_triage(self, org_slug: str, body: dict) -> dict: + """Update triaged alerts for an organization. + + Args: + org_slug: Organization slug + body: Alert triage configuration to update + """ + path = f"orgs/{org_slug}/triage/alerts" + + response = self.api.do_request(path=path, method="POST", payload=body) + + if 200 <= response.status_code < 300: + return response.json() + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error updating alert triage: {response.status_code}, message: {error_message}") + return {} diff --git a/socketdev/version.py b/socketdev/version.py index 6ff6d62..8cb37b5 100644 --- a/socketdev/version.py +++ b/socketdev/version.py @@ -1 +1 @@ -__version__ = "2.0.7" \ No newline at end of file +__version__ = "2.0.8"