|
| 1 | +// Copyright (c) Microsoft Corporation. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +use windows::{ |
| 5 | + core::*, |
| 6 | + Win32::Foundation::*, |
| 7 | + Win32::System::Com::*, |
| 8 | + Win32::System::UpdateAgent::*, |
| 9 | +}; |
| 10 | + |
| 11 | +use crate::windows_update::types::{UpdateInput, UpdateInfo, MsrcSeverity, UpdateType}; |
| 12 | + |
| 13 | +pub fn handle_export(input: &str) -> Result<String> { |
| 14 | + // Parse optional filter input |
| 15 | + let filter: UpdateInput = if input.trim().is_empty() { |
| 16 | + UpdateInput { |
| 17 | + title: None, |
| 18 | + id: None, |
| 19 | + is_installed: None, |
| 20 | + description: None, |
| 21 | + is_uninstallable: None, |
| 22 | + kb_article_ids: None, |
| 23 | + max_download_size: None, |
| 24 | + msrc_severity: None, |
| 25 | + security_bulletin_ids: None, |
| 26 | + update_type: None, |
| 27 | + } |
| 28 | + } else { |
| 29 | + serde_json::from_str(input) |
| 30 | + .map_err(|e| Error::new(E_INVALIDARG, format!("Failed to parse input: {}", e)))? |
| 31 | + }; |
| 32 | + |
| 33 | + // Initialize COM |
| 34 | + unsafe { |
| 35 | + CoInitializeEx(Some(std::ptr::null()), COINIT_MULTITHREADED).ok()?; |
| 36 | + } |
| 37 | + |
| 38 | + let result = unsafe { |
| 39 | + // Create update session |
| 40 | + let update_session: IUpdateSession = CoCreateInstance( |
| 41 | + &UpdateSession, |
| 42 | + None, |
| 43 | + CLSCTX_INPROC_SERVER, |
| 44 | + )?; |
| 45 | + |
| 46 | + // Create update searcher |
| 47 | + let searcher = update_session.CreateUpdateSearcher()?; |
| 48 | + |
| 49 | + // Build search criteria based on filters |
| 50 | + let search_criteria = match filter.is_installed { |
| 51 | + Some(true) => "IsInstalled=1", |
| 52 | + Some(false) => "IsInstalled=0", |
| 53 | + None => "IsInstalled=0 or IsInstalled=1", |
| 54 | + }; |
| 55 | + |
| 56 | + // Search for updates with optimized criteria |
| 57 | + let search_result = searcher.Search(&BSTR::from(search_criteria))?; |
| 58 | + |
| 59 | + // Get updates collection |
| 60 | + let updates = search_result.Updates()?; |
| 61 | + let count = updates.Count()?; |
| 62 | + |
| 63 | + // Collect all matching updates |
| 64 | + let mut found_updates: Vec<UpdateInfo> = Vec::new(); |
| 65 | + for i in 0..count { |
| 66 | + let update = updates.get_Item(i)?; |
| 67 | + let title = update.Title()?.to_string(); |
| 68 | + let identity = update.Identity()?; |
| 69 | + let update_id = identity.UpdateID()?.to_string(); |
| 70 | + |
| 71 | + // Extract all update information first for filtering |
| 72 | + let is_installed = update.IsInstalled()?.as_bool(); |
| 73 | + let description = update.Description()?.to_string(); |
| 74 | + let is_uninstallable = update.IsUninstallable()?.as_bool(); |
| 75 | + |
| 76 | + // Get KB Article IDs |
| 77 | + let kb_articles = update.KBArticleIDs()?; |
| 78 | + let kb_count = kb_articles.Count()?; |
| 79 | + let mut kb_article_ids = Vec::new(); |
| 80 | + for j in 0..kb_count { |
| 81 | + if let Ok(kb_str) = kb_articles.get_Item(j) { |
| 82 | + kb_article_ids.push(kb_str.to_string()); |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + let max_download_size = 0i64; |
| 87 | + |
| 88 | + // Get MSRC Severity |
| 89 | + let msrc_severity = if let Ok(severity_str) = update.MsrcSeverity() { |
| 90 | + match severity_str.to_string().as_str() { |
| 91 | + "Critical" => Some(MsrcSeverity::Critical), |
| 92 | + "Important" => Some(MsrcSeverity::Important), |
| 93 | + "Moderate" => Some(MsrcSeverity::Moderate), |
| 94 | + "Low" => Some(MsrcSeverity::Low), |
| 95 | + _ => None, |
| 96 | + } |
| 97 | + } else { |
| 98 | + None |
| 99 | + }; |
| 100 | + |
| 101 | + // Get Security Bulletin IDs |
| 102 | + let security_bulletins = update.SecurityBulletinIDs()?; |
| 103 | + let bulletin_count = security_bulletins.Count()?; |
| 104 | + let mut security_bulletin_ids = Vec::new(); |
| 105 | + for j in 0..bulletin_count { |
| 106 | + if let Ok(bulletin_str) = security_bulletins.get_Item(j) { |
| 107 | + security_bulletin_ids.push(bulletin_str.to_string()); |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + // Determine update type |
| 112 | + let update_type = { |
| 113 | + use windows::Win32::System::UpdateAgent::UpdateType as WinUpdateType; |
| 114 | + match update.Type()? { |
| 115 | + WinUpdateType(2) => UpdateType::Driver, |
| 116 | + _ => UpdateType::Software, |
| 117 | + } |
| 118 | + }; |
| 119 | + |
| 120 | + // Apply all filters |
| 121 | + let mut matches = true; |
| 122 | + |
| 123 | + // Filter by title with wildcard support |
| 124 | + if let Some(title_filter) = &filter.title { |
| 125 | + matches = matches && matches_wildcard(&title, title_filter); |
| 126 | + } |
| 127 | + |
| 128 | + // Filter by id |
| 129 | + if let Some(id_filter) = &filter.id { |
| 130 | + matches = matches && update_id.eq_ignore_ascii_case(id_filter); |
| 131 | + } |
| 132 | + |
| 133 | + // Filter by description with wildcard support |
| 134 | + if let Some(desc_filter) = &filter.description { |
| 135 | + matches = matches && matches_wildcard(&description, desc_filter); |
| 136 | + } |
| 137 | + |
| 138 | + // Filter by is_uninstallable |
| 139 | + if let Some(uninstallable_filter) = filter.is_uninstallable { |
| 140 | + matches = matches && (is_uninstallable == uninstallable_filter); |
| 141 | + } |
| 142 | + |
| 143 | + // Filter by KB article IDs (match if any KB ID in the filter is present) |
| 144 | + if let Some(kb_filter) = &filter.kb_article_ids { |
| 145 | + if !kb_filter.is_empty() { |
| 146 | + let kb_matches = kb_filter.iter().any(|filter_kb| { |
| 147 | + kb_article_ids.iter().any(|update_kb| update_kb.eq_ignore_ascii_case(filter_kb)) |
| 148 | + }); |
| 149 | + matches = matches && kb_matches; |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + // Filter by max_download_size (if specified, update size must be <= filter size) |
| 154 | + if let Some(size_filter) = filter.max_download_size { |
| 155 | + matches = matches && (max_download_size <= size_filter); |
| 156 | + } |
| 157 | + |
| 158 | + // Filter by MSRC severity |
| 159 | + if let Some(severity_filter) = &filter.msrc_severity { |
| 160 | + matches = matches && (msrc_severity.as_ref() == Some(severity_filter)); |
| 161 | + } |
| 162 | + |
| 163 | + // Filter by security bulletin IDs (match if any bulletin ID in the filter is present) |
| 164 | + if let Some(bulletin_filter) = &filter.security_bulletin_ids { |
| 165 | + if !bulletin_filter.is_empty() { |
| 166 | + let bulletin_matches = bulletin_filter.iter().any(|filter_bulletin| { |
| 167 | + security_bulletin_ids.iter().any(|update_bulletin| update_bulletin.eq_ignore_ascii_case(filter_bulletin)) |
| 168 | + }); |
| 169 | + matches = matches && bulletin_matches; |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + // Filter by update type |
| 174 | + if let Some(type_filter) = &filter.update_type { |
| 175 | + matches = matches && (&update_type == type_filter); |
| 176 | + } |
| 177 | + |
| 178 | + if matches { |
| 179 | + found_updates.push(UpdateInfo { |
| 180 | + title, |
| 181 | + is_installed, |
| 182 | + description, |
| 183 | + id: update_id, |
| 184 | + is_uninstallable, |
| 185 | + kb_article_ids, |
| 186 | + max_download_size, |
| 187 | + msrc_severity, |
| 188 | + security_bulletin_ids, |
| 189 | + update_type, |
| 190 | + }); |
| 191 | + } |
| 192 | + } |
| 193 | + |
| 194 | + Ok(found_updates) |
| 195 | + }; |
| 196 | + |
| 197 | + unsafe { |
| 198 | + CoUninitialize(); |
| 199 | + } |
| 200 | + |
| 201 | + match result { |
| 202 | + Ok(updates) => serde_json::to_string(&updates) |
| 203 | + .map_err(|e| Error::new(E_FAIL, format!("Failed to serialize output: {}", e))), |
| 204 | + Err(e) => Err(e), |
| 205 | + } |
| 206 | +} |
| 207 | + |
| 208 | +// Helper function to match string against pattern with wildcard (*) |
| 209 | +fn matches_wildcard(text: &str, pattern: &str) -> bool { |
| 210 | + let text_lower = text.to_lowercase(); |
| 211 | + let pattern_lower = pattern.to_lowercase(); |
| 212 | + |
| 213 | + // Split pattern by asterisks |
| 214 | + let parts: Vec<&str> = pattern_lower.split('*').collect(); |
| 215 | + |
| 216 | + // If no wildcard, it's an exact match (case-insensitive) |
| 217 | + if parts.len() == 1 { |
| 218 | + return text_lower == pattern_lower; |
| 219 | + } |
| 220 | + |
| 221 | + // If pattern is just asterisk(s), match everything |
| 222 | + if parts.is_empty() { |
| 223 | + return true; |
| 224 | + } |
| 225 | + |
| 226 | + // Check if pattern starts with asterisk |
| 227 | + let starts_with_wildcard = pattern_lower.starts_with('*'); |
| 228 | + // Check if pattern ends with asterisk |
| 229 | + let ends_with_wildcard = pattern_lower.ends_with('*'); |
| 230 | + |
| 231 | + let mut pos = 0; |
| 232 | + |
| 233 | + for (i, part) in parts.iter().enumerate() { |
| 234 | + if part.is_empty() { |
| 235 | + continue; |
| 236 | + } |
| 237 | + |
| 238 | + // For the first part, check if it should be at the start |
| 239 | + if i == 0 && !starts_with_wildcard { |
| 240 | + if !text_lower.starts_with(part) { |
| 241 | + return false; |
| 242 | + } |
| 243 | + pos = part.len(); |
| 244 | + } else { |
| 245 | + // Find the part in the remaining text |
| 246 | + if let Some(found_pos) = text_lower[pos..].find(part) { |
| 247 | + pos += found_pos + part.len(); |
| 248 | + } else { |
| 249 | + return false; |
| 250 | + } |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + // For the last part, check if it should be at the end |
| 255 | + if !ends_with_wildcard && !parts.is_empty() { |
| 256 | + if let Some(last_part) = parts.last() { |
| 257 | + if !last_part.is_empty() && !text_lower.ends_with(last_part) { |
| 258 | + return false; |
| 259 | + } |
| 260 | + } |
| 261 | + } |
| 262 | + |
| 263 | + true |
| 264 | +} |
0 commit comments