diff --git a/cryptoki/examples/benchmark_attributes.rs b/cryptoki/examples/benchmark_attributes.rs new file mode 100644 index 00000000..ebcb3908 --- /dev/null +++ b/cryptoki/examples/benchmark_attributes.rs @@ -0,0 +1,332 @@ +// Copyright 2025 Contributors to the Parsec project. +// SPDX-License-Identifier: Apache-2.0 +//! Benchmark example comparing get_attributes_old vs get_attributes +//! +//! This example demonstrates the performance difference between the original +//! and optimized implementations for retrieving object attributes. + +use cryptoki::context::{CInitializeArgs, Pkcs11}; +use cryptoki::mechanism::Mechanism; +use cryptoki::object::{Attribute, AttributeType, ObjectHandle}; +use cryptoki::session::{Session, UserType}; +use cryptoki::types::AuthPin; +use std::env; +use std::time::Instant; + +/// Statistics for a benchmark run +/// API calls are typically log-normally distributed, so we use that distribution +/// to compute geometric mean and percentiles. +struct BenchmarkStats { + mean: f64, + stddev: f64, + p50: f64, + p95: f64, + p99: f64, +} + +impl BenchmarkStats { + fn from_timings(mut timings: Vec) -> Self { + let iterations = timings.len(); + timings.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + let p50 = timings[iterations / 2]; + let p95 = timings[(iterations * 95) / 100]; + let p99 = timings[(iterations * 99) / 100]; + + // Geometric mean (appropriate for log-normal distribution) + let mean = (timings.iter().map(|x| x.ln()).sum::() / iterations as f64).exp(); + + // Standard deviation in log-space (geometric standard deviation) + let log_mean = timings.iter().map(|x| x.ln()).sum::() / iterations as f64; + let log_variance = timings + .iter() + .map(|x| (x.ln() - log_mean).powi(2)) + .sum::() + / iterations as f64; + let stddev = log_variance.sqrt().exp(); + + BenchmarkStats { + mean, + stddev, + p50, + p95, + p99, + } + } + + fn print(&self, label: &str) { + println!(" {}:", label); + println!(" distribution: log-normal"); + println!(" mean (geom): {:.2} µs", self.mean / 1000.0); + println!(" std dev (geom): {:.2}x", self.stddev); + println!(" p50 (median): {:.2} µs", self.p50 / 1000.0); + println!(" p95: {:.2} µs", self.p95 / 1000.0); + println!(" p99: {:.2} µs", self.p99 / 1000.0); + } +} + +struct BenchmarkResult { + label: String, + stats_old: BenchmarkStats, + stats_optimized: BenchmarkStats, +} + +impl BenchmarkResult { + fn speedup_mean(&self) -> f64 { + self.stats_old.mean / self.stats_optimized.mean + } +} + +/// Run a benchmark comparing get_attributes_old vs get_attributes +fn benchmark_attributes( + session: &Session, + object: ObjectHandle, + attributes: &[AttributeType], + iterations: usize, + label: &str, +) -> Result> { + println!("\n=== {} ===", label); + + // Benchmark get_attributes_old (original implementation) + println!( + "Benchmarking get_attributes_old() - {} iterations...", + iterations + ); + let mut timings_old = Vec::with_capacity(iterations); + for _ in 0..iterations { + let start = Instant::now(); + let _attrs = session.get_attributes_old(object, attributes)?; + timings_old.push(start.elapsed().as_nanos() as f64); + } + + // Benchmark get_attributes (optimized implementation) + println!( + "Benchmarking get_attributes() - {} iterations...", + iterations + ); + let mut timings_optimized = Vec::with_capacity(iterations); + for _ in 0..iterations { + let start = Instant::now(); + let _attrs = session.get_attributes(object, attributes)?; + timings_optimized.push(start.elapsed().as_nanos() as f64); + } + + let stats_old = BenchmarkStats::from_timings(timings_old); + let stats_optimized = BenchmarkStats::from_timings(timings_optimized); + + println!("\nResults:"); + stats_old.print("Original implementation"); + stats_optimized.print("Optimized implementation"); + + let speedup_mean = stats_old.mean / stats_optimized.mean; + let speedup_p95 = stats_old.p95 / stats_optimized.p95; + println!("\nSpeedup:"); + println!(" Based on mean (geom): {:.2}x", speedup_mean); + println!(" Based on p95: {:.2}x", speedup_p95); + + // Verify both methods return the same results + let attrs_old = session.get_attributes_old(object, attributes)?; + let attrs_optimized = session.get_attributes(object, attributes)?; + + println!("\nVerifying correctness..."); + println!( + " Original implementation returned {} attributes", + attrs_old.len() + ); + println!( + " Optimized implementation returned {} attributes", + attrs_optimized.len() + ); + + if attrs_old.len() != attrs_optimized.len() { + println!(" ✗ Implementations returned different number of attributes!"); + } else { + println!(" ✓ Both implementations returned the same number of attributes"); + + // Verify the order is the same + let mut order_matches = true; + for (i, (old_attr, opt_attr)) in attrs_old.iter().zip(attrs_optimized.iter()).enumerate() { + if std::mem::discriminant(old_attr) != std::mem::discriminant(opt_attr) { + println!( + " ✗ Attribute at position {} differs: {:?} vs {:?}", + i, old_attr, opt_attr + ); + order_matches = false; + } + } + + if order_matches { + println!(" ✓ Attributes are in the same order"); + } + } + + Ok(BenchmarkResult { + label: label.to_string(), + stats_old, + stats_optimized, + }) +} + +fn print_summary_table(results: &[BenchmarkResult]) { + println!("\n"); + println!("╔═══════════════════════════════════════════════════════════════════════════════════════════════════╗"); + println!("║ BENCHMARK SUMMARY TABLE ║"); + println!("╠═══════════════════╦═════════════╦═════════════╦═════════════╦═════════════╦═══════╦═══════════════╣"); + println!( + "║ {:^17} ║ {:>11} ║ {:>11} ║ {:>11} ║ {:>11} ║ {:^5} ║ {:^13} ║", + "Test Case", "Orig Mean", "Orig p95", "Opt Mean", "Opt p95", "Unit", "Speedup" + ); + println!("╠═══════════════════╬═════════════╬═════════════╬═════════════╬═════════════╬═══════╬═══════════════╣"); + + // Each row is a test case + for result in results { + println!( + "║ {:17} ║ {:11.2} ║ {:11.2} ║ {:11.2} ║ {:11.2} ║ {:>5} ║ {:>13} ║", + result.label, + result.stats_old.mean / 1000.0, + result.stats_old.p95 / 1000.0, + result.stats_optimized.mean / 1000.0, + result.stats_optimized.p95 / 1000.0, + "µs", + format!("x {:.2}", result.speedup_mean()) + ); + } + + println!("╚═══════════════════╩═════════════╩═════════════╩═════════════╩═════════════╩═══════╩═══════════════╝"); +} + +fn main() -> Result<(), Box> { + // how many iterations to run, default to 1000 + let iterations = env::var("TEST_BENCHMARK_ITERATIONS") + .unwrap_or_else(|_| "1000".to_string()) + .parse::()?; + + let pkcs11 = Pkcs11::new( + env::var("TEST_PKCS11_MODULE") + .unwrap_or_else(|_| "/usr/lib/softhsm/libsofthsm2.so".to_string()), + )?; + + let pin = env::var("TEST_PKCS11_PIN").unwrap_or_else(|_| "fedcba123456".to_string()); + pkcs11.initialize(CInitializeArgs::OsThreads)?; + + let nogenerate = env::var("TEST_PKCS11_NO_KEYGEN").is_ok(); + + let slot = pkcs11 + .get_slots_with_token()? + .into_iter() + .next() + .ok_or("No slot available")?; + + let session = pkcs11.open_rw_session(slot)?; + + session.login(UserType::User, Some(&AuthPin::new(pin.into())))?; + + let public; + let _private; + + if nogenerate { + // search for an elliptic curve public key. + // if more than one, take the first that comes. + println!("Using existing EC public key for benchmarking..."); + let template = vec![ + Attribute::Class(cryptoki::object::ObjectClass::PUBLIC_KEY), + Attribute::KeyType(cryptoki::object::KeyType::EC), + ]; + let objects = session.find_objects(&template)?; + if objects.is_empty() { + return Err( + "No EC public key found on the token. Cannot proceed with benchmarks.".into(), + ); + } + public = objects[0]; + } else { + // Generate a test EC key pair (P-256 curve) + let mechanism = Mechanism::EccKeyPairGen; + + // ANSI X9.62 prime256v1 (P-256) curve OID: 1.2.840.10045.3.1.7 + let ec_params = vec![0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07]; + + let pub_key_template = vec![ + Attribute::Token(false), // Don't persist + Attribute::Private(false), + Attribute::EcParams(ec_params), + Attribute::Verify(true), + Attribute::Label("Benchmark EC Key".into()), + Attribute::Id(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + ]; + + let priv_key_template = vec![Attribute::Token(false), Attribute::Sign(true)]; + + println!("Generating EC key pair for benchmarking..."); + (public, _private) = + session.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)?; + } + + let mut results = Vec::new(); + + // Test 1: Multiple attributes (mix of fixed and variable length) + let multiple_attributes = vec![ + AttributeType::Class, // CK_ULONG (fixed, 8 bytes) + AttributeType::Label, // Variable length + AttributeType::Id, // Variable length + AttributeType::KeyType, // CK_ULONG (fixed, 8 bytes) + AttributeType::Token, // CK_BBOOL (c_uchar, 1 byte) + AttributeType::Private, // CK_BBOOL (c_uchar, 1 byte) + AttributeType::EcPoint, // Variable length (~65 bytes for P-256 uncompressed) + AttributeType::EcParams, // Variable length (10 bytes for P-256 OID) + AttributeType::Verify, // CK_BBOOL (c_uchar, 1 byte) + AttributeType::Encrypt, // CK_BBOOL (c_uchar, 1 byte) + AttributeType::Local, // CK_BBOOL (c_uchar, 1 byte) + ]; + + results.push(benchmark_attributes( + &session, + public, + &multiple_attributes, + iterations, + "Multiple", + )?); + + // Test 2: Single fixed-length attribute (CK_ULONG) + let single_fixed = vec![AttributeType::KeyType]; + + results.push(benchmark_attributes( + &session, + public, + &single_fixed, + iterations, + "Single-fixed", + )?); + + // Test 3: Single variable-length attribute (EC point, ~65 bytes for P-256) + let single_variable = vec![AttributeType::EcPoint]; + + results.push(benchmark_attributes( + &session, + public, + &single_variable, + iterations, + "Single-variable", + )?); + + // Test 4: Single attribute that doesn't exist (Modulus for EC key) + let single_nonexistent = vec![AttributeType::Modulus]; + + results.push(benchmark_attributes( + &session, + public, + &single_nonexistent, + iterations, + "Single-nonexist", + )?); + + // Print summary table + print_summary_table(&results); + + // Clean up + if !nogenerate { + session.destroy_object(public)?; + } + + Ok(()) +} diff --git a/cryptoki/src/object.rs b/cryptoki/src/object.rs index 463a1007..71c5467b 100644 --- a/cryptoki/src/object.rs +++ b/cryptoki/src/object.rs @@ -322,6 +322,86 @@ impl AttributeType { _ => format!("unknown ({val:08x})"), } } + + /// Returns the fixed size of an attribute type if known. + /// + /// This method returns `Some(size)` for attributes with a known fixed size, + /// and `None` for variable-length attributes. This is useful for optimizing + /// attribute retrieval by pre-allocating buffers of the correct size. + /// + /// # Returns + /// + /// * `Some(usize)` - The fixed size in bytes for attributes with known fixed size + /// * `None` - For variable-length attributes (e.g., Label, Modulus, Value, etc.) + /// + /// # Examples + /// + /// ``` + /// use cryptoki::object::AttributeType; + /// use std::mem::size_of; + /// use cryptoki_sys::{CK_ULONG, CK_BBOOL}; + /// + /// // Fixed-size attributes + /// assert_eq!(AttributeType::Class.fixed_size(), Some(size_of::())); + /// assert_eq!(AttributeType::Token.fixed_size(), Some(size_of::())); + /// + /// // Variable-length attributes + /// assert_eq!(AttributeType::Label.fixed_size(), None); + /// assert_eq!(AttributeType::Modulus.fixed_size(), None); + /// ``` + pub fn fixed_size(&self) -> Option { + match self { + // CK_BBOOL + AttributeType::Token + | AttributeType::Private + | AttributeType::Modifiable + | AttributeType::Copyable + | AttributeType::Destroyable + | AttributeType::Sensitive + | AttributeType::Encrypt + | AttributeType::Decrypt + | AttributeType::Wrap + | AttributeType::Unwrap + | AttributeType::Sign + | AttributeType::SignRecover + | AttributeType::Verify + | AttributeType::VerifyRecover + | AttributeType::Derive + | AttributeType::Extractable + | AttributeType::Local + | AttributeType::NeverExtractable + | AttributeType::AlwaysSensitive + | AttributeType::WrapWithTrusted + | AttributeType::Trusted + | AttributeType::AlwaysAuthenticate + | AttributeType::Encapsulate + | AttributeType::Decapsulate => Some(size_of::()), + + // CK_ULONG or aliases (CK_OBJECT_CLASS, CK_KEY_TYPE, CK_CERTIFICATE_TYPE, etc.) + AttributeType::Class + | AttributeType::KeyType + | AttributeType::CertificateType + | AttributeType::ModulusBits + | AttributeType::ValueLen + | AttributeType::ObjectValidationFlags + | AttributeType::ParameterSet + | AttributeType::ValidationFlag + | AttributeType::ValidationType + | AttributeType::ValidationLevel + | AttributeType::ValidationAuthorityType + | AttributeType::ProfileId + | AttributeType::KeyGenMechanism => Some(size_of::()), + + // CK_DATE (8 bytes: year[4] + month[2] + day[2]) + AttributeType::StartDate | AttributeType::EndDate => Some(size_of::()), + + // CK_VERSION (2 bytes: major + minor) + AttributeType::ValidationVersion => Some(size_of::()), + + // Variable-length attributes (all the others) + _ => None, + } + } } impl std::fmt::Display for AttributeType { diff --git a/cryptoki/src/session/object_management.rs b/cryptoki/src/session/object_management.rs index 13034867..f7120c81 100644 --- a/cryptoki/src/session/object_management.rs +++ b/cryptoki/src/session/object_management.rs @@ -9,6 +9,7 @@ use crate::session::Session; use cryptoki_sys::*; use std::collections::HashMap; use std::convert::TryInto; +use std::ffi::c_void; use std::num::NonZeroUsize; // Search 10 elements at a time @@ -500,7 +501,7 @@ impl Session { /// Get the attributes values of an object. /// Ignore the unavailable one. One has to call the get_attribute_info method to check which /// ones are unavailable. - pub fn get_attributes( + pub fn get_attributes_old( &self, object: ObjectHandle, attributes: &[AttributeType], @@ -525,7 +526,7 @@ impl Session { .map(|(attr_type, memory)| { Ok(CK_ATTRIBUTE { type_: (*attr_type).into(), - pValue: memory.as_ptr() as *mut std::ffi::c_void, + pValue: memory.as_ptr() as *mut c_void, ulValueLen: memory.len().try_into()?, }) }) @@ -547,6 +548,179 @@ impl Session { template.into_iter().map(|attr| attr.try_into()).collect() } + /// Get the attributes values of an object, filtering out unavailable ones. + /// + /// # Arguments + /// + /// * `object` - The [ObjectHandle] used to reference the object + /// * `attributes` - The list of attribute types to retrieve + /// + /// # Returns + /// + /// A vector of [Attribute] containing the values of the available attributes. + /// + /// # Note + /// + /// This method follows PKCS#11 spec: in the first call, it provides pre-allocated buffers + /// for attributes with a known fixed size, and NULL pointers for other attributes to query + /// their size. + /// + /// After the first call: + /// - Attributes that fit in pre-allocated buffers are ready (pValue != NULL, valid ulValueLen) + /// - Attributes with NULL pValue but valid ulValueLen need a second fetch + /// - Attributes with CK_UNAVAILABLE_INFORMATION are skipped + /// + /// In total, a maximum of 2 calls to C_GetAttributeValue are made. + /// + pub fn get_attributes( + &self, + object: ObjectHandle, + attributes: &[AttributeType], + ) -> Result> { + // Step 1: Build pass1 template + // - Pre-allocate buffers for attributes with known fixed size + // - Use NULL pointers for all other attributes to query their size + let mut buffers: Vec> = Vec::with_capacity(attributes.len()); + let mut template1: Vec = Vec::with_capacity(attributes.len()); + + for attr_type in attributes.iter() { + if let Some(size) = attr_type.fixed_size() { + // We know the needed size, we allocate + let mut buffer = vec![0u8; size]; + template1.push(CK_ATTRIBUTE { + type_: (*attr_type).into(), + pValue: buffer.as_mut_ptr() as *mut c_void, + ulValueLen: size as CK_ULONG, + }); + buffers.push(buffer); + } else { + // This is a variable size, we set length to 0 and set the buffer ptr to NULL + template1.push(CK_ATTRIBUTE { + type_: (*attr_type).into(), + pValue: std::ptr::null_mut(), + ulValueLen: 0, + }); + buffers.push(Vec::new()); + } + } + + // Step 2: Make pass1 call to C_GetAttributeValue + let rv1 = unsafe { + Rv::from(get_pkcs11!(self.client(), C_GetAttributeValue)( + self.handle(), + object.handle(), + template1.as_mut_ptr(), + template1.len().try_into()?, + )) + }; + + match rv1 { + Rv::Ok + | Rv::Error(RvError::BufferTooSmall) + | Rv::Error(RvError::AttributeSensitive) + | Rv::Error(RvError::AttributeTypeInvalid) => { + // acceptable - we'll inspect ulValueLen/pValue + } + _ => { + rv1.into_result(Function::GetAttributeValue)?; + } + } + + // Step 3: Categorize pass1 results into pass1 (already satisfied) and pass2 (need fetch) + let mut pass1_indices: Vec = Vec::new(); + let mut pass2_indices: Vec = Vec::new(); + + for (i, attr) in template1.iter().enumerate() { + if attr.ulValueLen == CK_UNAVAILABLE_INFORMATION { + // Skip unavailable attributes + continue; + } else if attr.pValue.is_null() && attr.ulValueLen > 0 { + // NULL pointer but has a length - needs fetching in pass2 + pass2_indices.push(i); + } else if !attr.pValue.is_null() { + // If buffer was pre-allocated but too small, need to fetch in pass2 + if attr.ulValueLen > buffers[i].len() as CK_ULONG { + pass2_indices.push(i); + } else { + // Has data already - satisfied in pass1 + pass1_indices.push(i); + } + } + } + + // Step 4: Make pass2 call if needed for attributes that need fetching + let pass2_template_and_indices: Option<(Vec, Vec)> = + if pass2_indices.is_empty() { + None + } else { + let mut template2: Vec = Vec::with_capacity(pass2_indices.len()); + + for &idx in pass2_indices.iter() { + let size = template1[idx].ulValueLen as usize; + buffers[idx].resize(size, 0); + + template2.push(CK_ATTRIBUTE { + type_: template1[idx].type_, + pValue: buffers[idx].as_mut_ptr() as *mut c_void, + ulValueLen: buffers[idx].len() as CK_ULONG, + }); + } + + let rv2 = unsafe { + Rv::from(get_pkcs11!(self.client(), C_GetAttributeValue)( + self.handle(), + object.handle(), + template2.as_mut_ptr(), + template2.len().try_into()?, + )) + }; + + match rv2 { + Rv::Ok + | Rv::Error(RvError::AttributeSensitive) + | Rv::Error(RvError::AttributeTypeInvalid) => { + // acceptable + } + _ => { + rv2.into_result(Function::GetAttributeValue)?; + } + } + + // Add indices satisfied by pass2 into pass1_indices + for (i, &idx) in pass2_indices.iter().enumerate() { + if template2[i].ulValueLen != CK_UNAVAILABLE_INFORMATION { + pass1_indices.push(idx); + } + } + + Some((template2, pass2_indices)) + }; + + // Step 5: Build result Vec preserving request order + // Sort pass1_indices to preserve the original order from attributes[] + pass1_indices.sort_unstable(); + + let mut results = Vec::new(); + for &idx in pass1_indices.iter() { + let attr = if let Some((ref template2, ref indices2)) = pass2_template_and_indices { + if let Some(pos) = indices2.iter().position(|&i| i == idx) { + // attribute came from pass2 + template2[pos].try_into()? + } else { + // attribute came from pass1 + template1[idx].try_into()? + } + } else { + // Only pass1 happened + template1[idx].try_into()? + }; + + results.push(attr); + } + + Ok(results) + } + /// Sets the attributes of an object pub fn update_attributes(&self, object: ObjectHandle, template: &[Attribute]) -> Result<()> { let mut template: Vec = template.iter().map(|attr| attr.into()).collect(); diff --git a/cryptoki/tests/basic.rs b/cryptoki/tests/basic.rs index f6bbde4f..a34fb87b 100644 --- a/cryptoki/tests/basic.rs +++ b/cryptoki/tests/basic.rs @@ -4348,3 +4348,137 @@ fn object_handle_new_from_raw() -> TestResult { Ok(()) } + +#[test] +#[serial] +fn get_attributes_test() -> TestResult { + let (pkcs11, slot) = init_pins(); + + // open a session + let session = pkcs11.open_rw_session(slot)?; + + // log in the session + session.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))?; + + // get mechanism + let mechanism = Mechanism::RsaPkcsKeyPairGen; + + let public_exponent: Vec = vec![0x01, 0x00, 0x01]; + let modulus_bits = 2048; + + // pub key template + let pub_key_template = vec![ + Attribute::Token(true), + Attribute::Private(false), + Attribute::PublicExponent(public_exponent.clone()), + Attribute::ModulusBits(modulus_bits.into()), + Attribute::Verify(true), + ]; + + // priv key template + let priv_key_template = vec![ + Attribute::Token(true), + Attribute::Sign(true), + Attribute::Private(true), + ]; + + // generate a key pair + let (public, private) = + session.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)?; + + // Test get_attributes_fast with various attribute types + let attributes_to_check = vec![ + AttributeType::Class, + AttributeType::KeyType, + AttributeType::Token, + AttributeType::Private, + AttributeType::Modulus, + AttributeType::PublicExponent, + AttributeType::Verify, + AttributeType::ModulusBits, + ]; + + // Test 1: Get multiple attributes from public key + let attrs = session.get_attributes(public, &attributes_to_check)?; + + // Check that we got the expected attributes + assert!(!attrs.is_empty(), "No attributes returned"); + + // Verify specific attributes + let has_class = attrs.iter().any(|attr| matches!(attr, Attribute::Class(_))); + let has_key_type = attrs + .iter() + .any(|attr| matches!(attr, Attribute::KeyType(_))); + let has_modulus = attrs + .iter() + .any(|attr| matches!(attr, Attribute::Modulus(_))); + let has_public_exp = attrs + .iter() + .any(|attr| matches!(attr, Attribute::PublicExponent(_))); + + assert!(has_class, "Class attribute not found"); + assert!(has_key_type, "KeyType attribute not found"); + assert!(has_modulus, "Modulus attribute not found"); + assert!(has_public_exp, "PublicExponent attribute not found"); + + // Verify the public exponent value matches what we set + for attr in &attrs { + if let Attribute::PublicExponent(exp) = attr { + assert_eq!(exp, &public_exponent, "Public exponent mismatch"); + } + } + + // Test 2: Get multiple attributes from private key + let priv_attributes_to_check = vec![ + AttributeType::Class, + AttributeType::KeyType, + AttributeType::Token, + AttributeType::Private, + AttributeType::Sign, + ]; + + let priv_attrs = session.get_attributes(private, &priv_attributes_to_check)?; + + assert!(!priv_attrs.is_empty(), "No private key attributes returned"); + + // Test 3: Single attribute with known fixed length (CK_ULONG) + let single_fixed = vec![AttributeType::KeyType]; + let attrs_single_fixed = session.get_attributes(public, &single_fixed)?; + assert_eq!( + attrs_single_fixed.len(), + 1, + "Should return exactly 1 attribute" + ); + assert!( + matches!(attrs_single_fixed[0], Attribute::KeyType(_)), + "Should be KeyType attribute" + ); + + // Test 4: Single attribute with variable length + let single_variable = vec![AttributeType::Modulus]; + let attrs_single_variable = session.get_attributes(public, &single_variable)?; + assert_eq!( + attrs_single_variable.len(), + 1, + "Should return exactly 1 attribute" + ); + assert!( + matches!(attrs_single_variable[0], Attribute::Modulus(_)), + "Should be Modulus attribute" + ); + + // Test 5: Single attribute that doesn't exist for this object (EC point for RSA key) + let single_invalid = vec![AttributeType::EcPoint]; + let attrs_single_invalid = session.get_attributes(public, &single_invalid)?; + assert_eq!( + attrs_single_invalid.len(), + 0, + "Should return 0 attributes for invalid attribute type" + ); + + // delete keys + session.destroy_object(public)?; + session.destroy_object(private)?; + + Ok(()) +} diff --git a/cryptoki/tests/common/mod.rs b/cryptoki/tests/common/mod.rs index e7968542..0d2b7f64 100644 --- a/cryptoki/tests/common/mod.rs +++ b/cryptoki/tests/common/mod.rs @@ -17,14 +17,24 @@ fn get_pkcs11_path() -> String { .unwrap_or_else(|_| "/usr/local/lib/softhsm/libsofthsm2.so".to_string()) } +// Used to simulate different library behaviors. +// for SoftHSM, just create the environment variable TEST_PRETEND_LIBRARY with "softhsm" +// This is used to interface a shim library during testing, while appearing to be using the real library. +#[allow(dead_code)] +pub fn get_pretend_library() -> String { + env::var("TEST_PRETEND_LIBRARY") + .unwrap_or_else(|_| "".to_string()) + .to_lowercase() +} + #[allow(dead_code)] pub fn is_softhsm() -> bool { - get_pkcs11_path().contains("softhsm") + get_pretend_library() == "softhsm" || get_pkcs11_path().contains("softhsm") } #[allow(dead_code)] pub fn is_kryoptic() -> bool { - get_pkcs11_path().contains("kryoptic") + get_pretend_library() == "kryoptic" || get_pkcs11_path().contains("kryoptic") } #[allow(dead_code)]