Skip to content

Commit 07e0244

Browse files
committed
optimized creating new diff
1 parent 3ef84ac commit 07e0244

File tree

2 files changed

+101
-82
lines changed

2 files changed

+101
-82
lines changed

socketsecurity/core/__init__.py

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -216,30 +216,64 @@ def load_files_for_sending(files: List[str], workspace: str) -> List[Tuple[str,
216216

217217
return send_files
218218

219-
def create_full_scan(self, files: List[str], params: FullScanParams) -> FullScan:
220-
"""
221-
Creates a new full scan via the Socket API.
222-
223-
Args:
224-
files: List of files to scan
225-
params: Parameters for the full scan
226-
227-
Returns:
228-
FullScan object with scan results
229-
"""
219+
def create_full_scan(self, files: List[str], params: FullScanParams, store_results: bool = True) -> FullScan:
220+
"""Creates a new full scan via the Socket API."""
230221
log.debug("Creating new full scan")
231222
create_full_start = time.time()
232223

224+
# Time the post API call
225+
post_start = time.time()
233226
res = self.sdk.fullscans.post(files, params)
227+
post_end = time.time()
228+
log.debug(f"API fullscans.post took {post_end - post_start:.2f} seconds")
229+
234230
if not res.success:
235231
log.error(f"Error creating full scan: {res.message}, status: {res.status}")
236232
raise Exception(f"Error creating full scan: {res.message}, status: {res.status}")
237233

238234
full_scan = FullScan(**asdict(res.data))
235+
236+
if not store_results:
237+
full_scan.sbom_artifacts = []
238+
full_scan.packages = {}
239+
return full_scan
239240

240-
full_scan_artifacts_dict = self.get_sbom_data(full_scan.id)
241-
full_scan.sbom_artifacts = self.get_sbom_data_list(full_scan_artifacts_dict)
242-
full_scan.packages = self.create_packages_dict(full_scan.sbom_artifacts)
241+
# Time the stream API call
242+
stream_start = time.time()
243+
artifacts_response = self.sdk.fullscans.stream(self.config.org_slug, full_scan.id)
244+
stream_end = time.time()
245+
log.debug(f"API fullscans.stream took {stream_end - stream_start:.2f} seconds")
246+
247+
if not artifacts_response.success:
248+
log.error(f"Failed to get SBOM data for full-scan {full_scan.id}")
249+
log.error(artifacts_response.message)
250+
full_scan.sbom_artifacts = []
251+
full_scan.packages = {}
252+
return full_scan
253+
254+
# Store the original SocketArtifact objects
255+
full_scan.sbom_artifacts = list(artifacts_response.artifacts.values())
256+
257+
# Create packages dictionary directly from the artifacts
258+
packages = {}
259+
top_level_count = {}
260+
261+
for artifact in artifacts_response.artifacts.values():
262+
package = Package.from_socket_artifact(artifact)
263+
if package.id not in packages:
264+
package.license_text = self.get_package_license_text(package)
265+
packages[package.id] = package
266+
267+
# Count top-level ancestors in the same pass
268+
if package.topLevelAncestors:
269+
for top_id in package.topLevelAncestors:
270+
top_level_count[top_id] = top_level_count.get(top_id, 0) + 1
271+
272+
# Update transitive counts
273+
for package in packages.values():
274+
package.transitives = top_level_count.get(package.id, 0)
275+
276+
full_scan.packages = packages
243277

244278
create_full_end = time.time()
245279
total_time = create_full_end - create_full_start
@@ -351,22 +385,18 @@ def get_head_scan_for_repo(self, repo_slug: str) -> str:
351385
return repo_info.head_full_scan_id if repo_info.head_full_scan_id else None
352386

353387
def get_added_and_removed_packages(self, head_full_scan_id: Optional[str], new_full_scan: FullScan) -> Tuple[Dict[str, Package], Dict[str, Package]]:
354-
"""
355-
Get packages that were added and removed between scans.
356-
357-
Args:
358-
head_full_scan: Previous scan (may be None if first scan)
359-
new_full_scan: New scan just created
360-
361-
Returns:
362-
Tuple of (added_packages, removed_packages) dictionaries
363-
"""
388+
"""Get packages that were added and removed between scans."""
364389
if head_full_scan_id is None:
365390
log.info(f"No head scan found. New scan ID: {new_full_scan.id}")
366391
return new_full_scan.packages, {}
367392

368393
log.info(f"Comparing scans - Head scan ID: {head_full_scan_id}, New scan ID: {new_full_scan.id}")
394+
395+
# Time the stream_diff API call
396+
diff_start = time.time()
369397
diff_report = self.sdk.fullscans.stream_diff(self.config.org_slug, head_full_scan_id, new_full_scan.id).data
398+
diff_end = time.time()
399+
log.debug(f"API fullscans.stream_diff took {diff_end - diff_start:.2f} seconds")
370400

371401
log.info(f"Diff report artifact counts:")
372402
log.info(f"Added: {len(diff_report.artifacts.added)}")
@@ -383,12 +413,12 @@ def get_added_and_removed_packages(self, head_full_scan_id: Optional[str], new_f
383413

384414
for artifact in added_artifacts:
385415
try:
386-
pkg = Package.from_diff_artifact(asdict(artifact))
416+
pkg = Package.from_diff_artifact(artifact)
387417
added_packages[artifact.id] = pkg
388418
except KeyError:
389419
log.error(f"KeyError: Could not create package from added artifact {artifact.id}")
390420
log.error(f"Artifact details - name: {artifact.name}, version: {artifact.version}")
391-
matches = [p for p in new_full_scan.packages.values() if p.name == artifact.name and p.version == artifact.version]
421+
matches = [p for p in added_artifacts.values() if p.name == artifact.name and p.version == artifact.version]
392422
if matches:
393423
log.error(f"Found {len(matches)} packages with matching name/version:")
394424
for m in matches:
@@ -403,7 +433,7 @@ def get_added_and_removed_packages(self, head_full_scan_id: Optional[str], new_f
403433
except KeyError:
404434
log.error(f"KeyError: Could not create package from removed artifact {artifact.id}")
405435
log.error(f"Artifact details - name: {artifact.name}, version: {artifact.version}")
406-
matches = [p for p in head_full_scan.packages.values() if p.name == artifact.name and p.version == artifact.version]
436+
matches = [p for p in removed_artifacts.values() if p.name == artifact.name and p.version == artifact.version]
407437
if matches:
408438
log.error(f"Found {len(matches)} packages with matching name/version:")
409439
for m in matches:
@@ -419,14 +449,7 @@ def create_new_diff(
419449
params: FullScanParams,
420450
no_change: bool = False
421451
) -> Diff:
422-
"""Create a new diff using the Socket SDK.
423-
424-
Args:
425-
path: Path to look for manifest files
426-
params: Query params for the Full Scan endpoint
427-
428-
no_change: If True, return empty diff
429-
"""
452+
"""Create a new diff using the Socket SDK."""
430453
print(f"starting create_new_diff with no_change: {no_change}")
431454
if no_change:
432455
return Diff(id="no_diff_id")
@@ -439,20 +462,16 @@ def create_new_diff(
439462
if not files:
440463
return Diff(id="no_diff_id")
441464

465+
# Initialize head scan ID
442466
head_full_scan_id = None
443-
444467
try:
445468
# Get head scan ID
446469
head_full_scan_id = self.get_head_scan_for_repo(params.repo)
447470
except APIResourceNotFound:
448-
head_full_scan_id = None
449-
450-
# Create new scan
451-
new_scan_start = time.time()
452-
new_full_scan = self.create_full_scan(files_for_sending, params)
453-
new_scan_end = time.time()
454-
log.info(f"Total time to create new full scan: {new_scan_end - new_scan_start:.2f}")
471+
pass
455472

473+
# Create new scan - only store results if we don't have a head scan to diff against
474+
new_full_scan = self.create_full_scan(files_for_sending, params, store_results=head_full_scan_id is None)
456475

457476
added_packages, removed_packages = self.get_added_and_removed_packages(head_full_scan_id, new_full_scan)
458477

socketsecurity/core/classes.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass, field
33
from typing import Dict, List, TypedDict, Any, Optional
44

5-
from socketdev.fullscans import FullScanMetadata, SocketArtifact, SocketArtifactLink, DiffType, SocketManifestReference, SocketScore, SocketAlert
5+
from socketdev.fullscans import FullScanMetadata, SocketArtifact, SocketArtifactLink, SocketScore, SocketAlert, DiffArtifact
66

77
__all__ = [
88
"Report",
@@ -123,40 +123,40 @@ class Package(SocketArtifactLink):
123123
url: str = ""
124124

125125
@classmethod
126-
def from_socket_artifact(cls, data: dict) -> "Package":
126+
def from_socket_artifact(cls, artifact: SocketArtifact) -> "Package":
127127
"""
128-
Create a Package from a SocketArtifact dictionary.
128+
Create a Package from a SocketArtifact.
129129
130130
Args:
131-
data: Dictionary containing SocketArtifact data
131+
artifact: SocketArtifact instance from scan results
132132
133133
Returns:
134134
New Package instance
135135
"""
136136
return cls(
137-
id=data["id"],
138-
name=data["name"],
139-
version=data["version"],
140-
type=data["type"],
141-
score=data["score"],
142-
alerts=data["alerts"],
143-
author=data.get("author", []),
144-
size=data.get("size"),
145-
license=data.get("license"),
146-
topLevelAncestors=data["topLevelAncestors"],
147-
direct=data.get("direct", False),
148-
manifestFiles=data.get("manifestFiles", []),
149-
dependencies=data.get("dependencies"),
150-
artifact=data.get("artifact")
137+
id=artifact.id,
138+
name=artifact.name,
139+
version=artifact.version,
140+
type=artifact.type,
141+
score=artifact.score,
142+
alerts=artifact.alerts,
143+
author=artifact.author or [],
144+
size=artifact.size,
145+
license=artifact.license,
146+
topLevelAncestors=artifact.topLevelAncestors,
147+
direct=artifact.direct,
148+
manifestFiles=artifact.manifestFiles,
149+
dependencies=artifact.dependencies,
150+
artifact=artifact.artifact
151151
)
152152

153153
@classmethod
154-
def from_diff_artifact(cls, data: dict) -> "Package":
154+
def from_diff_artifact(cls, artifact: DiffArtifact) -> "Package":
155155
"""
156-
Create a Package from a DiffArtifact dictionary.
156+
Create a Package from a DiffArtifact.
157157
158158
Args:
159-
data: Dictionary containing DiffArtifact data
159+
artifact: DiffArtifact instance from diff results
160160
161161
Returns:
162162
New Package instance
@@ -165,29 +165,29 @@ def from_diff_artifact(cls, data: dict) -> "Package":
165165
ValueError: If reference data cannot be found in DiffArtifact
166166
"""
167167
ref = None
168-
if data["diffType"] in ["added", "updated"] and data.get("head"):
169-
ref = data["head"][0]
170-
elif data["diffType"] in ["removed", "replaced"] and data.get("base"):
171-
ref = data["base"][0]
168+
if artifact.diffType in ["added", "updated"] and artifact.head:
169+
ref = artifact.head[0]
170+
elif artifact.diffType in ["removed", "replaced"] and artifact.base:
171+
ref = artifact.base[0]
172172

173173
if not ref:
174174
raise ValueError("Could not find reference data in DiffArtifact")
175175

176176
return cls(
177-
id=data["id"],
178-
name=data["name"],
179-
version=data["version"],
180-
type=data["type"],
181-
score=data["score"],
182-
alerts=data["alerts"],
183-
author=data.get("author", []),
184-
size=data.get("size"),
185-
license=data.get("license"),
186-
topLevelAncestors=ref["topLevelAncestors"],
187-
direct=ref.get("direct", False),
188-
manifestFiles=ref.get("manifestFiles", []),
189-
dependencies=ref.get("dependencies"),
190-
artifact=ref.get("artifact")
177+
id=artifact.id,
178+
name=artifact.name,
179+
version=artifact.version,
180+
type=artifact.type,
181+
score=artifact.score,
182+
alerts=artifact.alerts,
183+
author=artifact.author or [],
184+
size=artifact.size,
185+
license=artifact.license,
186+
topLevelAncestors=ref.topLevelAncestors,
187+
direct=ref.direct,
188+
manifestFiles=ref.manifestFiles,
189+
dependencies=ref.dependencies,
190+
artifact=ref.artifact
191191
)
192192

193193
class Issue:

0 commit comments

Comments
 (0)