Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 166 additions & 56 deletions modules/fundamental/src/vulnerability/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use trustify_common::{
purl::Purl,
};
use trustify_entity::{
advisory, advisory_vulnerability_score, organization, remediation::RemediationCategory,
advisory, advisory_vulnerability_score, cpe, organization, remediation::RemediationCategory,
vulnerability, vulnerability_description,
};
use trustify_module_ingestor::common::Deprecation;
Expand All @@ -48,6 +48,7 @@ struct AnalysisData {
descriptions_map: HashMap<String, vulnerability_description::Model>,
scores: Vec<advisory_vulnerability_score::Model>,
advisories_map: HashMap<Uuid, AdvisoryData>,
cpe_map: HashMap<Uuid, cpe::Model>,
}

#[derive(Default)]
Expand Down Expand Up @@ -176,30 +177,38 @@ impl VulnerabilityService {
purls_with_vulnerabilities.len()
);

// Extract advisory IDs from JSONB array
// Extract advisory IDs and CPE IDs from JSONB array
#[derive(serde::Deserialize)]
struct AdvisoryEntry {
advisory_id: Uuid,
context_cpe: Option<Uuid>,
}
impl sea_orm::TryGetableFromJson for AdvisoryEntry {}

// Extract all unique vulnerability_ids and advisory_ids in a single pass
// Extract all unique vulnerability_ids, advisory_ids, and cpe_ids in a single pass
let mut vulnerability_ids: Vec<String> =
Vec::with_capacity(purls_with_vulnerabilities.len());
let mut advisory_ids_set: std::collections::HashSet<Uuid> =
std::collections::HashSet::new();
let mut cpe_ids_set: std::collections::HashSet<Uuid> = std::collections::HashSet::new();

for purl_with_vulnerabilities in &purls_with_vulnerabilities {
vulnerability_ids.push(purl_with_vulnerabilities.try_get("", "id")?);

if let Some(advisories) = purl_with_vulnerabilities
.try_get_by::<Option<Vec<AdvisoryEntry>>, _>("advisories")?
{
advisory_ids_set.extend(advisories.into_iter().map(|e| e.advisory_id));
for entry in advisories {
advisory_ids_set.insert(entry.advisory_id);
if let Some(cpe_id) = entry.context_cpe {
cpe_ids_set.insert(cpe_id);
}
}
}
}

let advisory_ids: Vec<Uuid> = advisory_ids_set.into_iter().collect();
let cpe_ids: Vec<Uuid> = cpe_ids_set.into_iter().collect();

// Pre-fetch vulnerability descriptions
let vulnerability_descriptions = if !vulnerability_ids.is_empty() {
Expand Down Expand Up @@ -272,12 +281,27 @@ impl VulnerabilityService {
})
.collect();

// Pre-fetch CPEs for product_status context
let cpe_map: HashMap<Uuid, cpe::Model> = if !cpe_ids.is_empty() {
cpe::Entity::find()
.filter(Expr::col(cpe::Column::Id).eq(PgFunc::any(cpe_ids)))
.all(connection)
.await?
.into_iter()
.map(|c| (c.id, c))
.collect()
} else {
HashMap::new()
};
log::debug!("Pre-fetched {} CPEs", cpe_map.len());

Ok(AnalysisData {
purls_with_vulnerabilities,
warnings,
descriptions_map,
scores,
advisories_map: advisories,
cpe_map,
})
}

Expand All @@ -296,6 +320,7 @@ impl VulnerabilityService {
descriptions_map,
scores,
advisories_map,
cpe_map,
} = data;

// Build a map of (advisory_id, vulnerability_id) -> Vec<Model> for score calculation
Expand All @@ -317,6 +342,7 @@ impl VulnerabilityService {
&descriptions_map,
&scores_map,
&advisories_map,
&cpe_map,
)
.await?;

Expand Down Expand Up @@ -356,6 +382,7 @@ impl VulnerabilityService {
descriptions_map,
scores,
advisories_map,
cpe_map: _,
} = data;

let mut scores_map: HashMap<(Uuid, String), Vec<Score>> = HashMap::new();
Expand Down Expand Up @@ -400,40 +427,16 @@ impl VulnerabilityService {
Ok(AnalysisResponse(result))
}

/// Build the query for finding matching vulnerabilities
fn build_query(
purls: impl IntoIterator<Item = impl AsRef<str>>,
connection: &impl ConnectionTrait,
warnings: &mut HashMap<String, Vec<String>>,
) -> Result<String, Error> {
let query = purls
.into_iter()
.map(|p| {
let p = p.as_ref();
let purl = Purl::from_str(p)?;

let Some(version) = purl.version else {
warnings
.entry(p.to_string())
.or_default()
.push("Unable to process: missing version component".to_string());
return Ok(None);
};

let ns_condition = match &purl.namespace {
Some(namespace) => {
let sql = "base_purl.namespace = $1";
Statement::from_sql_and_values(
connection.get_database_backend(),
sql,
[namespace.into()],
)
.to_string()
}
None => "base_purl.namespace IS NULL".to_string(),
};

let sql = format!(
#[inline(always)]
/// Builds each individual part of the vulnerabilities query with parameters for
/// querying for either purl status vulnerabilities or product status vulnerabilities.
fn build_vulnerabilities_query_string(
advisory_columns: &str,
remediations_tables: &str,
vulnerabilities_tables: &str,
conditions: &str,
) -> String {
format!(
r#"
SELECT
$1 as requested_purl,
Expand All @@ -447,7 +450,7 @@ SELECT
jsonb_agg(
DISTINCT jsonb_build_object(
'status', status.slug,
'advisory_id', purl_status.advisory_id,
{advisory_columns},
'version_range', jsonb_build_object(
'version_scheme_id', version_range.version_scheme_id,
'low_version', version_range.low_version,
Expand All @@ -465,21 +468,12 @@ SELECT
'data', r.data
)
), '[]'::jsonb)
FROM remediation_purl_status rps
JOIN remediation r ON r.id = rps.remediation_id
WHERE rps.purl_status_id = purl_status.id
FROM {remediations_tables}
)
)
) AS advisories
FROM base_purl
LEFT JOIN purl_status ON base_purl.id = purl_status.base_purl_id
INNER JOIN version_range ON purl_status.version_range_id = version_range.id
LEFT JOIN vulnerability ON purl_status.vulnerability_id = vulnerability.id
INNER JOIN status ON purl_status.status_id = status.id
WHERE {ns_condition}
AND base_purl.name = $2
AND base_purl.type = $3
AND version_matches($4, version_range.*) = TRUE
FROM {vulnerabilities_tables}
WHERE {conditions}
AND status.slug NOT IN (
'fixed',
'not_affected',
Expand All @@ -495,14 +489,123 @@ GROUP BY
vulnerability.cwes,
requested_purl
"#
)
}

/// Build the query for finding matching vulnerabilities
fn build_query(
purls: impl IntoIterator<Item = impl AsRef<str>>,
connection: &impl ConnectionTrait,
warnings: &mut HashMap<String, Vec<String>>,
) -> Result<String, Error> {
let query = purls
.into_iter()
.map(|p| {
let p = p.as_ref();
let purl = Purl::from_str(p)?;

let Some(version) = purl.version else {
warnings
.entry(p.to_string())
.or_default()
.push("Unable to process: missing version component".to_string());
return Ok(None);
};

let ns_condition = match &purl.namespace {
Some(namespace) => {
let sql = "base_purl.namespace = $1";
Statement::from_sql_and_values(
connection.get_database_backend(),
sql,
[namespace.into()],
)
.to_string()
}
None => "base_purl.namespace IS NULL".to_string(),
};

let purl_status_sql = Self::build_vulnerabilities_query_string(
r#"'advisory_id', purl_status.advisory_id"#,
r#" remediation_purl_status rps
JOIN remediation r ON r.id = rps.remediation_id
WHERE rps.purl_status_id = purl_status.id
"#,
r#" base_purl
LEFT JOIN purl_status ON base_purl.id = purl_status.base_purl_id
INNER JOIN version_range ON purl_status.version_range_id = version_range.id
LEFT JOIN vulnerability ON purl_status.vulnerability_id = vulnerability.id
INNER JOIN status ON purl_status.status_id = status.id
"#,
format!(r#" {ns_condition}
AND base_purl.name = $2
AND base_purl.type = $3
AND version_matches($4, version_range.*) = TRUE
"#).as_str()
);
let query = Statement::from_sql_and_values(

let purl_status_query = Statement::from_sql_and_values(
connection.get_database_backend(),
&sql,
[p.into(), purl.name.into(), purl.ty.into(), version.into()],
&purl_status_sql,
[
p.into(),
purl.name.clone().into(),
purl.ty.clone().into(),
version.into(),
],
);

Ok(Some(query.to_string()))
let package_condition = match &purl.namespace {
Some(namespace) => {
let full_name = format!("{}/{}", namespace, purl.name);
let sql = "(product_status.package = $1 OR product_status.package = $2)";
Statement::from_sql_and_values(
connection.get_database_backend(),
sql,
[full_name.into(), purl.name.clone().into()],
)
.to_string()
}
None => {
let sql = "product_status.package = $1";
Statement::from_sql_and_values(
connection.get_database_backend(),
sql,
[purl.name.clone().into()],
)
.to_string()
}
};

let product_status_sql = Self::build_vulnerabilities_query_string(
r#" 'advisory_id', product_status.advisory_id,
'context_cpe', cpe.id
"#,
r#" remediation_product_status rps
JOIN remediation r ON r.id = rps.remediation_id
WHERE rps.product_status_id = product_status.id
"#,
r#" product_status
JOIN status ON product_status.status_id = status.id
JOIN vulnerability ON product_status.vulnerability_id = vulnerability.id
JOIN product_version_range ON product_status.product_version_range_id = product_version_range.id
JOIN version_range ON product_version_range.version_range_id = version_range.id
LEFT JOIN cpe ON product_status.context_cpe_id = cpe.id
"#,
format!(r#" {package_condition}
AND product_status.package IS NOT NULL
"#).as_str()
);
let product_status_query = Statement::from_sql_and_values(
connection.get_database_backend(),
&product_status_sql,
[p.into()],
);

Ok(Some(format!(
"{} UNION ALL {}",
purl_status_query, product_status_query
)))
})
.filter_map(Result::transpose)
.collect::<Result<Vec<String>, Error>>()?
Expand All @@ -525,6 +628,7 @@ GROUP BY
descriptions_map: &HashMap<String, vulnerability_description::Model>,
scores_map: &HashMap<(Uuid, String), Vec<advisory_vulnerability_score::Model>>,
advisories_map: &HashMap<Uuid, AdvisoryData>,
cpe_map: &HashMap<Uuid, cpe::Model>,
) -> Result<(String, AnalysisDetailsV3), Error>
where
C: ConnectionTrait,
Expand Down Expand Up @@ -557,6 +661,7 @@ GROUP BY
advisory_id: Uuid,
version_range: VersionRange,
remediations: Vec<RemediationEntry>,
context_cpe: Option<Uuid>,
}
impl sea_orm::TryGetableFromJson for AdvisoryEntry {}

Expand Down Expand Up @@ -614,6 +719,11 @@ GROUP BY
})
.collect();

let cpe_string = entry
.context_cpe
.and_then(|id| cpe_map.get(&id))
.map(|c| c.to_string());

let purl_status = PurlStatus::from_head(
head.clone(),
AdvisoryHead::from_advisory(
Expand All @@ -624,7 +734,7 @@ GROUP BY
.await?,
entry.status.clone(),
Some(entry.version_range.clone()),
None,
cpe_string,
score_models,
)?;
purl_statuses.push(AnalysisPurlStatus {
Expand Down
Loading