Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,10 @@ public Vulnerability getVulnerabilityByVulnId(Vulnerability.Source source, Strin
return getVulnerabilityQueryManager().getVulnerabilityByVulnId(source, vulnId, includeVulnerableSoftware);
}

public List<Vulnerability> getVulnerabilitiesBySourceAndVulnIds(Collection<VulnIdAndSource> vulnIdsAndSources) {
return getVulnerabilityQueryManager().getVulnerabilitiesBySourceAndVulnIds(vulnIdsAndSources);
}

public void addVulnerability(Vulnerability vulnerability, Component component, AnalyzerIdentity analyzerIdentity) {
getVulnerabilityQueryManager().addVulnerability(vulnerability, component, analyzerIdentity);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -168,8 +169,8 @@ public Vulnerability synchronizeVulnerability(Vulnerability vulnerability, boole
} else {
// Update only if changes are detected
return hasChanges(existingVulnerability, vulnerability)
? updateVulnerability(vulnerability, commitIndex)
: null;
? updateVulnerability(vulnerability, commitIndex)
: null;
}
});
}
Expand Down Expand Up @@ -383,6 +384,62 @@ public PaginatedResult getVulnerabilities(Component component, boolean includeSu
return result;
}

/**
* Returns vulnerabilities by source and vulnerability IDs.
* Efficiently fetches multiple vulnerabilities based on specific source/ID pairs.
* @param vulnIdsAndSources collection of {@link VulnIdAndSource} records to fetch
* @return list of matching Vulnerability objects
*/
public List<Vulnerability> getVulnerabilitiesBySourceAndVulnIds(
Collection<VulnIdAndSource> vulnIdsAndSources) {

if (vulnIdsAndSources == null || vulnIdsAndSources.isEmpty()) {
return Collections.emptyList();
}

Map<String, Set<String>> sourceToVulnIds = vulnIdsAndSources.stream()
.collect(Collectors.groupingBy(
v -> v.source().name(),
Collectors.mapping(VulnIdAndSource::vulnId, Collectors.toSet())
));

final Query<Vulnerability> query = pm.newQuery(Vulnerability.class);
query.getFetchPlan().addGroup(Vulnerability.FetchGroup.COMPONENTS.name());

StringBuilder filter = new StringBuilder();
Map<String, Object> paramValues = new HashMap<>();

int idx = 0;
for (Map.Entry<String, Set<String>> entry : sourceToVulnIds.entrySet()) {
if (!filter.isEmpty()) {
filter.append(" || ");
}

String sourceParam = "source" + idx;
String idsParam = "vulnId" + idx;

// Constructs: (source == :source0 && :vulnId0.contains(vulnId))
filter.append("(source == :").append(sourceParam)
.append(" && :").append(idsParam).append(".contains(vulnId))");

paramValues.put(sourceParam, entry.getKey());
paramValues.put(idsParam, entry.getValue());

idx++;
}

query.setFilter(filter.toString());
query.setNamedParameters(paramValues);
List<Vulnerability> fetched = executeAndCloseList(query);

if (fetched == null || fetched.isEmpty()) {
return Collections.emptyList();
}

return fetched;
}


/**
* Returns a List of Vulnerability for the specified Component and excludes suppressed vulnerabilities.
* This method if designed NOT to provide paginated results.
Expand Down Expand Up @@ -680,7 +737,7 @@ public List<VulnerabilityAlias> getVulnerabilityAliases(Vulnerability vulnerabil
} else {
query = pm.newQuery(VulnerabilityAlias.class, "internalId == :internalId");
}
return (List<VulnerabilityAlias>)query.execute(vulnerability.getVulnId());
return (List<VulnerabilityAlias>)query.execute(vulnerability.getVulnId());
}


Expand Down
222 changes: 185 additions & 37 deletions src/main/java/org/dependencytrack/tasks/scanners/TrivyAnalysisTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.dependencytrack.model.Component;
import org.dependencytrack.model.ComponentProperty;
import org.dependencytrack.model.ConfigPropertyConstants;
import org.dependencytrack.model.VulnIdAndSource;
import org.dependencytrack.model.Vulnerability;
import org.dependencytrack.model.VulnerabilityAnalysisLevel;
import org.dependencytrack.parser.trivy.TrivyParser;
Expand All @@ -71,9 +72,14 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNullElseGet;
import static org.dependencytrack.common.ConfigKey.TRIVY_RETRY_BACKOFF_INITIAL_DURATION_MS;
Expand Down Expand Up @@ -360,11 +366,186 @@ private void handleResults(final Map<String, Component> componentByPurl, final A
}
}

for (final Map.Entry<Component, List<trivy.proto.common.Vulnerability>> entry : vulnsByComponent.entrySet()) {
final Component component = entry.getKey();
final List<trivy.proto.common.Vulnerability> vulns = entry.getValue();
handle(component, vulns);
handleAllVulnsByComponent(vulnsByComponent);
}

private void handleAllVulnsByComponent(final Map<Component, List<trivy.proto.common.Vulnerability>> vulnsByComponent) {
if (vulnsByComponent.isEmpty()) {
return;
}

try (var qm = new QueryManager()) {
final var trivyParser = new TrivyParser();

final Map<UUID, Component> persistentComponents =
resolvePersistentComponents(qm, vulnsByComponent.keySet());

final Map<trivy.proto.common.Vulnerability, Vulnerability> parsedCache =
parseAllVulnerabilities(trivyParser, vulnsByComponent.values());

final Map<VulnIdAndSource, Vulnerability> existingVulnerabilities =
fetchExistingVulnerabilities(qm, parsedCache.values());

final Set<VulnerabilityAddition> additions =
processComponents(
qm,
vulnsByComponent,
persistentComponents,
parsedCache,
existingVulnerabilities
);

finalizeAdditions(qm, additions);

qm.getPersistenceManager().evictAll();

boolean hasNewVulnerabilities = hasNewVulnerabilities(additions, existingVulnerabilities);

parsedCache.clear();
existingVulnerabilities.clear();
additions.clear();
persistentComponents.clear();

if (hasNewVulnerabilities) {
Event.dispatch(new IndexEvent(IndexEvent.Action.COMMIT, Vulnerability.class));
}
}
}

private Map<UUID, Component> resolvePersistentComponents(
QueryManager qm,
Set<Component> components
) {
Set<UUID> uuids = components.stream()
.map(Component::getUuid)
.collect(Collectors.toSet());

if (uuids.isEmpty()) {
return Map.of();
}

List<UUID> uuidList = new ArrayList<>(uuids);

List<Component> persistentComponents =
qm.getObjectsByUuids(Component.class, uuidList);

return persistentComponents.stream()
.collect(Collectors.toMap(
Component::getUuid,
Function.identity()
));
}

private Map<trivy.proto.common.Vulnerability, Vulnerability> parseAllVulnerabilities(
TrivyParser parser,
Collection<List<trivy.proto.common.Vulnerability>> trivyLists
) {
Map<trivy.proto.common.Vulnerability, Vulnerability> cache = new HashMap<>();

for (List<trivy.proto.common.Vulnerability> list : trivyLists) {
for (trivy.proto.common.Vulnerability trivyVuln : list) {
cache.computeIfAbsent(trivyVuln, parser::parse);
}
}
return cache;
}


private Map<VulnIdAndSource, Vulnerability> fetchExistingVulnerabilities(
QueryManager qm,
Collection<Vulnerability> parsedVulns
) {
Set<VulnIdAndSource> identifiers = parsedVulns.stream()
.map(v -> new VulnIdAndSource(v.getVulnId(), v.getSource()))
.collect(Collectors.toSet());

Map<VulnIdAndSource, Vulnerability> result = new HashMap<>();

for (Vulnerability vuln : qm.getVulnerabilitiesBySourceAndVulnIds(identifiers)) {
result.put(new VulnIdAndSource(vuln.getVulnId(), vuln.getSource()), vuln);
}
return result;
}

private Set<VulnerabilityAddition> processComponents(
QueryManager qm,
Map<Component, List<trivy.proto.common.Vulnerability>> vulnsByComponent,
Map<UUID, Component> persistentComponents,
Map<trivy.proto.common.Vulnerability, Vulnerability> parsedCache,
Map<VulnIdAndSource, Vulnerability> existingVulnerabilities
) {
Set<VulnerabilityAddition> additions = new HashSet<>();

for (var entry : vulnsByComponent.entrySet()) {
Component component = entry.getKey();
Component persistent = persistentComponents.get(component.getUuid());

if (persistent == null) {
LOGGER.warn("""
%s vulnerabilities were reported for component %s, \
but it no longer exists; Skipping""".formatted( entry.getValue().size(), component.getUuid()));
continue;
}

for (var trivyVuln : entry.getValue()) {
Vulnerability parsedVulnerability = parsedCache.get(trivyVuln);
VulnIdAndSource key = new VulnIdAndSource(parsedVulnerability.getVulnId(), parsedVulnerability.getSource());

Vulnerability persisted = existingVulnerabilities.get(key);

if (persisted == null) {
LOGGER.debug("Creating unavailable vulnerability: %s - %s".formatted(parsedVulnerability.getSource(), parsedVulnerability.getVulnId()));
persisted = qm.createVulnerability(parsedVulnerability, false);
existingVulnerabilities.put(key, persisted);
} else if (severityChanged(parsedVulnerability, persisted)) {
qm.updateVulnerability(parsedVulnerability, false);
}

additions.add(new VulnerabilityAddition(persisted, persistent));
}
}
return additions;
}

private void finalizeAdditions(
QueryManager qm,
Set<VulnerabilityAddition> additions
) {
int count = 0;

for (VulnerabilityAddition addition : additions) {
LOGGER.debug("Trivy vulnerability added: %s to component %s".formatted(addition.vulnerability.getVulnId(), addition.component.getName()));

NotificationUtil.analyzeNotificationCriteria(
qm, addition.vulnerability, addition.component, vulnerabilityAnalysisLevel
);
qm.addVulnerability(addition.vulnerability, addition.component, this.getAnalyzerIdentity());

if (++count % 20 == 0) {
qm.getPersistenceManager().flush();
}
}

qm.getPersistenceManager().flush();
}

private boolean severityChanged(Vulnerability parsed, Vulnerability existing) {
return parsed.getSeverity() != null
&& !parsed.getSeverity().equals(existing.getSeverity());
}

private boolean hasNewVulnerabilities(
Set<VulnerabilityAddition> additions,
Map<VulnIdAndSource, Vulnerability> existing
) {
return additions.stream()
.map(VulnerabilityAddition::vulnerability)
.anyMatch(v ->
!existing.containsKey(new VulnIdAndSource(v.getVulnId(), v.getSource()))
);
}

private record VulnerabilityAddition(Vulnerability vulnerability, Component component) {
}

private ArrayList<Result> analyzeBlob(final Collection<BlobInfo> blobs) {
Expand Down Expand Up @@ -479,39 +660,6 @@ private void deleteBlob(final PutBlobRequest putBlobRequest) {
}
}

private void handle(final Component component, final Collection<trivy.proto.common.Vulnerability> trivyVulns) {
try (final var qm = new QueryManager()) {
final var trivyParser = new TrivyParser();
final var persistentComponent = qm.getObjectByUuid(Component.class, component.getUuid());
if (persistentComponent == null) {
LOGGER.warn("""
%s vulnerabilities were reported for component %s, \
but it no longer exists; Skipping""".formatted(trivyVulns.size(), component.getUuid()));
return;
}

boolean didCreateVulns = false;
for (final trivy.proto.common.Vulnerability trivyVuln : trivyVulns) {
final Vulnerability parsedVulnerability = trivyParser.parse(trivyVuln);

Vulnerability vulnerability = qm.getVulnerabilityByVulnId(parsedVulnerability.getSource(), parsedVulnerability.getVulnId());
if (vulnerability == null) {
LOGGER.debug("Creating unavailable vulnerability:" + parsedVulnerability.getSource() + " - " + parsedVulnerability.getVulnId());
vulnerability = qm.createVulnerability(parsedVulnerability, false);
didCreateVulns = true;
}

LOGGER.debug("Trivy vulnerability added: " + vulnerability.getVulnId() + " to component " + persistentComponent.getName());
NotificationUtil.analyzeNotificationCriteria(qm, vulnerability, persistentComponent, vulnerabilityAnalysisLevel);
qm.addVulnerability(vulnerability, persistentComponent, this.getAnalyzerIdentity());
}

if (didCreateVulns) {
Event.dispatch(new IndexEvent(IndexEvent.Action.COMMIT, Vulnerability.class));
}
}
}

private Optional<String> getApiBaseUrl() {
if (apiBaseUrl != null) {
return Optional.of(apiBaseUrl);
Expand Down
Loading