diff --git a/.env.dev b/.env.dev index 029284b..dc7a6f2 100644 --- a/.env.dev +++ b/.env.dev @@ -1,6 +1,9 @@ # [ad_server] TRUSTED_SERVER__AD_SERVER__AD_PARTNER_URL=http://127.0.0.1:10180 +# [publisher] +TRUSTED_SERVER__PUBLISHER__ORIGIN_URL=http://localhost:9090 + # [synthetic] TRUSTED_SERVER__SYNTHETIC__COUNTER_STORE=counter_store TRUSTED_SERVER__SYNTHETIC__OPID_STORE=opid_store diff --git a/Cargo.lock b/Cargo.lock index ef78534..05dddf7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -141,9 +141,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.5.0" +version = "8.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "2.5.1" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -274,6 +274,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "crunchy" version = "0.2.4" @@ -366,13 +375,34 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl 1.0.0", +] + [[package]] name = "derive_more" version = "2.0.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" dependencies = [ - "derive_more-impl", + "derive_more-impl 2.0.1", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", + "unicode-xid", ] [[package]] @@ -544,6 +574,16 @@ dependencies = [ "log", ] +[[package]] +name = "flate2" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1605,9 +1645,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.9.0" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f271e09bde39ab52250160a67e88577e0559ad77e9085de6e9051a2c4353f8f8" +checksum = "ed0aee96c12fa71097902e0bb061a5e1ebd766a6636bb605ba401c45c1650eac" dependencies = [ "indexmap", "serde", @@ -1629,18 +1669,18 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5c1c469eda89749d2230d8156a5969a69ffe0d6d01200581cdc6110674d293e" +checksum = "97200572db069e74c512a14117b296ba0a80a30123fbbb5aa1f4a348f639ca30" dependencies = [ "winnow", ] [[package]] name = "toml_writer" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b679217f2848de74cabd3e8fc5e6d66f40b7da40f8e1954d92054d9010690fd5" +checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64" [[package]] name = "trusted-server-common" @@ -1650,9 +1690,11 @@ dependencies = [ "chrono", "config", "cookie", - "derive_more", + 
"derive_more 1.0.0", + "derive_more 2.0.1", "error-stack", "fastly", + "flate2", "futures", "handlebars", "hex", @@ -1965,9 +2007,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74c7b26e3480b707944fc872477815d29a8e429d2f93a1ce000f5fa84a15cbcd" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 4e07985..6661fdf 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -9,13 +9,14 @@ publish = false license = "Apache-2.0" [dependencies] -brotli = "3.3" +brotli = "8.0" chrono = "0.4" config = "0.15.11" cookie = "0.18.1" derive_more = { version = "2.0", features = ["display", "error"] } error-stack = "0.5" fastly = "0.11.5" +flate2 = "1.0" futures = "0.3" handlebars = "6.3.2" hex = "0.4.3" @@ -35,10 +36,11 @@ urlencoding = "2.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.91" config = "0.15.11" -derive_more = { version = "2.0", features = ["display", "error"] } +derive_more = { version = "1.0", features = ["display", "error"] } error-stack = "0.5" http = "1.3.1" toml = "0.9.0" +url = "2.4.1" [dev-dependencies] regex = "1.1.1" diff --git a/crates/common/src/error.rs b/crates/common/src/error.rs index 75cf498..3527068 100644 --- a/crates/common/src/error.rs +++ b/crates/common/src/error.rs @@ -19,7 +19,15 @@ pub enum TrustedServerError { #[display("Configuration error: {message}")] Configuration { message: String }, + /// GAM (Google Ad Manager) integration error. + #[display("GAM error: {message}")] + Gam { message: String }, + /// GDPR consent handling error. + #[display("GDPR consent error: {message}")] + GdprConsent { message: String }, + /// The synthetic secret key is using the insecure default value. 
+ #[display("Synthetic secret key is set to the default value - this is insecure")] InsecureSecretKey, @@ -31,30 +39,26 @@ pub enum TrustedServerError { #[display("Invalid HTTP header value: {message}")] InvalidHeaderValue { message: String }, + /// Key-value store operation failed. + #[display("KV store error: {store_name} - {message}")] + KvStore { store_name: String, message: String }, + + /// Prebid integration error. + #[display("Prebid error: {message}")] + Prebid { message: String }, + + /// Proxy error. + #[display("Proxy error: {message}")] + Proxy { message: String }, + /// Settings parsing or validation failed. #[display("Settings error: {message}")] Settings { message: String }, - /// GAM (Google Ad Manager) integration error. - #[display("GAM error: {message}")] - Gam { message: String }, - - /// GDPR consent handling error. - #[display("GDPR consent error: {message}")] - GdprConsent { message: String }, - /// Synthetic ID generation or validation failed. #[display("Synthetic ID error: {message}")] SyntheticId { message: String }, - /// Prebid integration error. - #[display("Prebid error: {message}")] - Prebid { message: String }, - - /// Key-value store operation failed. - #[display("KV store error: {store_name} - {message}")] - KvStore { store_name: String, message: String }, - /// Template rendering error. #[display("Template error: {message}")] Template { message: String }, @@ -76,14 +80,15 @@ impl IntoHttpResponse for TrustedServerError { fn status_code(&self) -> StatusCode { match self { Self::Configuration { .. } | Self::Settings { .. } => StatusCode::INTERNAL_SERVER_ERROR, - Self::InsecureSecretKey => StatusCode::INTERNAL_SERVER_ERROR, - Self::InvalidUtf8 { .. } => StatusCode::BAD_REQUEST, - Self::InvalidHeaderValue { .. } => StatusCode::BAD_REQUEST, Self::Gam { .. } => StatusCode::BAD_GATEWAY, Self::GdprConsent { .. } => StatusCode::BAD_REQUEST, - Self::SyntheticId { .. } => StatusCode::INTERNAL_SERVER_ERROR, - Self::Prebid { .. 
} => StatusCode::BAD_GATEWAY, + Self::InsecureSecretKey => StatusCode::INTERNAL_SERVER_ERROR, + Self::InvalidHeaderValue { .. } => StatusCode::BAD_REQUEST, + Self::InvalidUtf8 { .. } => StatusCode::BAD_REQUEST, Self::KvStore { .. } => StatusCode::SERVICE_UNAVAILABLE, + Self::Prebid { .. } => StatusCode::BAD_GATEWAY, + Self::Proxy { .. } => StatusCode::BAD_GATEWAY, + Self::SyntheticId { .. } => StatusCode::INTERNAL_SERVER_ERROR, Self::Template { .. } => StatusCode::INTERNAL_SERVER_ERROR, } } diff --git a/crates/common/src/gam.rs b/crates/common/src/gam.rs index 5723366..1e78626 100644 --- a/crates/common/src/gam.rs +++ b/crates/common/src/gam.rs @@ -1,6 +1,3 @@ -use crate::error::TrustedServerError; -use crate::gdpr::get_consent_from_request; -use crate::settings::Settings; use error_stack::Report; use fastly::http::{header, Method, StatusCode}; use fastly::{Request, Response}; @@ -9,6 +6,11 @@ use std::collections::HashMap; use std::io::Read; use uuid::Uuid; +use crate::error::TrustedServerError; +use crate::gdpr::get_consent_from_request; +use crate::settings::Settings; +use crate::templates::GAM_TEST_TEMPLATE; + /// GAM request builder for server-side ad requests pub struct GamRequest { pub publisher_id: String, @@ -1060,6 +1062,16 @@ pub async fn handle_gam_asset( } } +pub fn handle_gam_test_page( + _settings: &Settings, + _req: Request, +) -> Result> { + Ok(Response::from_status(StatusCode::OK) + .with_body(GAM_TEST_TEMPLATE) + .with_header(header::CONTENT_TYPE, "text/html") + .with_header("x-compress-hint", "on")) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index a3c28ad..430ac37 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -15,6 +15,7 @@ //! - [`prebid`]: Prebid integration and real-time bidding support //! - [`privacy`]: Privacy utilities and helpers //! - [`settings`]: Configuration management and validation +//! 
- [`streaming_replacer`]: Streaming URL replacement for large responses //! - [`synthetic`]: Synthetic ID generation using HMAC //! - [`templates`]: Handlebars template handling //! - [`test_support`]: Testing utilities and mocks @@ -34,6 +35,7 @@ pub mod privacy; pub mod publisher; pub mod settings; pub mod settings_data; +pub mod streaming_replacer; pub mod synthetic; pub mod templates; pub mod test_support; diff --git a/crates/common/src/publisher.rs b/crates/common/src/publisher.rs index 20b0512..cdcf809 100644 --- a/crates/common/src/publisher.rs +++ b/crates/common/src/publisher.rs @@ -1,6 +1,11 @@ -use error_stack::Report; +use brotli::enc::{writer::CompressorWriter, BrotliEncoderParams}; +use brotli::Decompressor; +use error_stack::{Report, ResultExt}; use fastly::http::{header, StatusCode}; -use fastly::{Request, Response}; +use fastly::{Body, Request, Response}; +use flate2::read::{GzDecoder, ZlibDecoder}; +use flate2::write::{GzEncoder, ZlibEncoder}; +use flate2::Compression; use crate::constants::{ HEADER_SYNTHETIC_FRESH, HEADER_SYNTHETIC_TRUSTED_SERVER, HEADER_X_COMPRESS_HINT, @@ -12,9 +17,69 @@ use crate::error::TrustedServerError; use crate::gdpr::{get_consent_from_request, GdprConsent}; use crate::geo::get_dma_code; use crate::settings::Settings; +use crate::streaming_replacer::{create_url_replacer, stream_process}; use crate::synthetic::{generate_synthetic_id, get_or_generate_synthetic_id}; use crate::templates::{EDGEPUBS_TEMPLATE, HTML_TEMPLATE}; +/// Detects the request scheme (HTTP or HTTPS) using Fastly SDK methods and headers. +/// +/// Tries multiple methods in order of reliability: +/// 1. Fastly SDK TLS detection methods (most reliable) +/// 2. Forwarded header (RFC 7239) +/// 3. X-Forwarded-Proto header +/// 4. Fastly-SSL header (least reliable, can be spoofed) +/// 5. Default to HTTP +fn detect_request_scheme(req: &Request) -> String { + // 1. 
First try Fastly SDK's built-in TLS detection methods + // These are the most reliable as they check the actual connection + if let Some(tls_protocol) = req.get_tls_protocol() { + // If we have a TLS protocol, the connection is definitely HTTPS + log::debug!("TLS protocol detected: {}", tls_protocol); + return "https".to_string(); + } + + // Also check TLS cipher - if present, connection is HTTPS + if req.get_tls_cipher_openssl_name().is_some() { + log::debug!("TLS cipher detected, using HTTPS"); + return "https".to_string(); + } + + // 2. Try the Forwarded header (RFC 7239) + if let Some(forwarded) = req.get_header("forwarded") { + if let Ok(forwarded_str) = forwarded.to_str() { + // Parse the Forwarded header + // Format: Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 + if forwarded_str.contains("proto=https") { + return "https".to_string(); + } else if forwarded_str.contains("proto=http") { + return "http".to_string(); + } + } + } + + // 3. Try X-Forwarded-Proto header + if let Some(proto) = req.get_header("x-forwarded-proto") { + if let Ok(proto_str) = proto.to_str() { + let proto_lower = proto_str.to_lowercase(); + if proto_lower == "https" || proto_lower == "http" { + return proto_lower; + } + } + } + + // 4. Check Fastly-SSL header (can be spoofed by clients, use as last resort) + if let Some(ssl) = req.get_header("fastly-ssl") { + if let Ok(ssl_str) = ssl.to_str() { + if ssl_str == "1" || ssl_str.to_lowercase() == "true" { + return "https".to_string(); + } + } + } + + // 5. Default to HTTP when none of the above indicate HTTPS + "http".to_string() +} + /// Handles the main page request. /// /// Serves the main page with synthetic ID generation and ad integration. 
@@ -125,6 +190,245 @@ pub fn handle_main_page( Ok(response) } +/// Process response body in streaming fashion with compression preservation +fn process_response_streaming( + body: Body, + content_encoding: &str, + origin_host: &str, + origin_url: &str, + request_host: &str, + request_scheme: &str, +) -> Result> { + const CHUNK_SIZE: usize = 8192; // 8KB chunks + + // Create the streaming replacer for URL replacements + let mut replacer = create_url_replacer(origin_host, origin_url, request_host, request_scheme); + + // Create output body + let mut output_body = Body::new(); + + // Determine if content needs decompression/recompression + let is_compressed = matches!(content_encoding, "gzip" | "deflate" | "br"); + + if is_compressed { + // For compressed content, we stream through: + // 1. Decompress chunks + // 2. Process them + // 3. Recompress and write to output + + log::info!( + "Processing compressed content with encoding: {}", + content_encoding + ); + + match content_encoding { + "gzip" => { + // Create gzip decompressor + let decoder = GzDecoder::new(body); + // Create gzip compressor + let mut encoder = GzEncoder::new(output_body, Compression::default()); + + // Stream through the pipeline + stream_process(decoder, &mut encoder, &mut replacer, CHUNK_SIZE).map_err(|e| { + Report::new(TrustedServerError::Proxy { + message: format!("Failed to process stream: {}", e), + }) + })?; + + // Finish compression and get the output body + output_body = encoder.finish().change_context(TrustedServerError::Proxy { + message: "Failed to finish gzip compression".to_string(), + })?; + } + "deflate" => { + // Create deflate decompressor + let decoder = ZlibDecoder::new(body); + // Create deflate compressor + let mut encoder = ZlibEncoder::new(output_body, Compression::default()); + + // Stream through the pipeline + stream_process(decoder, &mut encoder, &mut replacer, CHUNK_SIZE).map_err(|e| { + Report::new(TrustedServerError::Proxy { + message: format!("Failed to process 
stream: {}", e), + }) + })?; + + // Finish compression and get the output body + output_body = encoder.finish().change_context(TrustedServerError::Proxy { + message: "Failed to finish deflate compression".to_string(), + })?; + } + "br" => { + // Create Brotli decompressor + let decoder = Decompressor::new(body, 4096); // 4KB buffer + + // Create Brotli compressor with reasonable parameters + // Quality 4 gives good balance of speed and compression + let params = BrotliEncoderParams { + quality: 4, + lgwin: 22, // 4MB window + ..Default::default() + }; + + // Create Brotli compressor writer + let mut encoder = CompressorWriter::with_params(output_body, 4096, ¶ms); + + // Stream through the pipeline + stream_process(decoder, &mut encoder, &mut replacer, CHUNK_SIZE).map_err(|e| { + Report::new(TrustedServerError::Proxy { + message: format!("Failed to process Brotli stream: {}", e), + }) + })?; + + // Finish compression and get the output body + // Note: into_inner() returns the inner writer (Body), not a Result + output_body = encoder.into_inner(); + } + _ => unreachable!(), + } + } else { + // For uncompressed content, we can truly stream + log::info!("Processing uncompressed content"); + + // Stream directly from body to output + stream_process(body, &mut output_body, &mut replacer, CHUNK_SIZE).map_err(|e| { + Report::new(TrustedServerError::Proxy { + message: format!("Failed to process stream: {}", e), + }) + })?; + } + + log::info!("Streaming processing complete"); + Ok(output_body) +} + +/// Proxies requests to the publisher's origin server. +/// +/// This function forwards incoming requests to the configured origin URL, +/// preserving headers and request body. It's used as a fallback for routes +/// not explicitly handled by the trusted server. 
+/// +/// # Errors +/// +/// Returns a [`TrustedServerError`] if: +/// - The proxy request fails +/// - The origin backend is unreachable +pub fn handle_publisher_request( + settings: &Settings, + mut req: Request, +) -> Result> { + log::info!("Proxying request to publisher_origin"); + + // Extract the request host from the incoming request + let request_host = req + .get_header(header::HOST) + .map(|h| h.to_str().unwrap_or_default()) + .unwrap_or_default() + .to_string(); + + // Detect the request scheme using multiple methods + let request_scheme = detect_request_scheme(&req); + + // Log detection details for debugging + log::info!( + "Scheme detection - TLS Protocol: {:?}, TLS Cipher: {:?}, Forwarded: {:?}, X-Forwarded-Proto: {:?}, Fastly-SSL: {:?}, Result: {}", + req.get_tls_protocol(), + req.get_tls_cipher_openssl_name(), + req.get_header("forwarded"), + req.get_header("x-forwarded-proto"), + req.get_header("fastly-ssl"), + request_scheme + ); + + log::info!("Request host: {}, scheme: {}", request_host, request_scheme); + + // Extract host from the origin_url using the Publisher's origin_host method + let origin_host = settings.publisher.origin_host(); + + log::info!("Setting host header to: {}", origin_host); + req.set_header("host", &origin_host); + + // Send the request to the origin backend + let mut response = req + .send(&settings.publisher.origin_backend) + .change_context(TrustedServerError::Proxy { + message: "Failed to proxy request".to_string(), + })?; + + // Log all response headers for debugging + log::info!("Response headers:"); + for (name, value) in response.get_headers() { + log::info!(" {}: {:?}", name, value); + } + + // Check if the response has a text-based content type that we should process + let content_type = response + .get_header(header::CONTENT_TYPE) + .map(|h| h.to_str().unwrap_or_default()) + .unwrap_or_default(); + + let should_process = content_type.contains("text/") + || content_type.contains("application/javascript") + || 
content_type.contains("application/json"); + + if should_process && !request_host.is_empty() { + // Check if the response is compressed + let content_encoding = response + .get_header(header::CONTENT_ENCODING) + .map(|h| h.to_str().unwrap_or_default()) + .unwrap_or_default() + .to_lowercase(); + + // Log response details for debugging + log::info!( + "Processing response - Content-Type: {}, Content-Encoding: {}, Request Host: {}, Origin Host: {}", + content_type, content_encoding, request_host, origin_host + ); + + // Take the response body for streaming processing + let body = response.take_body(); + + // Process the body using streaming approach + match process_response_streaming( + body, + &content_encoding, + &origin_host, + &settings.publisher.origin_url, + &request_host, + &request_scheme, + ) { + Ok(processed_body) => { + // Set the processed body back + response.set_body(processed_body); + + // Remove Content-Length as the size has likely changed + response.remove_header(header::CONTENT_LENGTH); + + // Keep Content-Encoding header since we're returning compressed content + log::info!( + "Preserved Content-Encoding: {} for compressed response", + content_encoding + ); + + log::info!("Completed streaming processing of response body"); + } + Err(e) => { + log::error!("Failed to process response body: {:?}", e); + // Return an error response + return Err(e); + } + } + } else { + log::info!( + "Skipping response processing - should_process: {}, request_host: '{}'", + should_process, + request_host + ); + } + + Ok(response) +} + /// Handles the EdgePubs page request. /// /// Serves the EdgePubs landing page with integrated ad slots. 
@@ -181,3 +485,340 @@ pub fn handle_edgepubs_page( Ok(response) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_support::tests::create_test_settings; + use fastly::http::Method; + + #[test] + fn test_detect_request_scheme() { + // Note: In tests, we can't mock the TLS methods on Request, so we test header fallbacks + + // Test Forwarded header with HTTPS + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("forwarded", "for=192.0.2.60;proto=https;by=203.0.113.43"); + assert_eq!(detect_request_scheme(&req), "https"); + + // Test Forwarded header with HTTP + let mut req = Request::new(Method::GET, "http://test.example.com/page"); + req.set_header("forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43"); + assert_eq!(detect_request_scheme(&req), "http"); + + // Test X-Forwarded-Proto with HTTPS + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("x-forwarded-proto", "https"); + assert_eq!(detect_request_scheme(&req), "https"); + + // Test X-Forwarded-Proto with HTTP + let mut req = Request::new(Method::GET, "http://test.example.com/page"); + req.set_header("x-forwarded-proto", "http"); + assert_eq!(detect_request_scheme(&req), "http"); + + // Test Fastly-SSL header + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("fastly-ssl", "1"); + assert_eq!(detect_request_scheme(&req), "https"); + + // Test default to HTTP when no headers present + let req = Request::new(Method::GET, "https://test.example.com/page"); + assert_eq!(detect_request_scheme(&req), "http"); + + // Test priority: Forwarded takes precedence over X-Forwarded-Proto + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("forwarded", "proto=https"); + req.set_header("x-forwarded-proto", "http"); + assert_eq!(detect_request_scheme(&req), "https"); + } + + #[test] + fn test_handle_publisher_request_extracts_headers() { + // Test that 
the function correctly extracts host and scheme from request headers + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("host", "test.example.com"); + req.set_header("x-forwarded-proto", "https"); + + // Extract headers like the function does + let request_host = req + .get_header("host") + .map(|h| h.to_str().unwrap_or_default()) + .unwrap_or_default() + .to_string(); + + let request_scheme = req + .get_header("x-forwarded-proto") + .and_then(|h| h.to_str().ok()) + .unwrap_or("https") + .to_string(); + + assert_eq!(request_host, "test.example.com"); + assert_eq!(request_scheme, "https"); + } + + #[test] + fn test_handle_publisher_request_default_https_scheme() { + // Test default HTTPS when x-forwarded-proto is missing + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("host", "test.example.com"); + // No x-forwarded-proto header + + let request_scheme = req + .get_header("x-forwarded-proto") + .and_then(|h| h.to_str().ok()) + .unwrap_or("https"); + + assert_eq!(request_scheme, "https"); + } + + #[test] + fn test_handle_publisher_request_http_scheme() { + // Test HTTP scheme detection + let mut req = Request::new(Method::GET, "http://test.example.com/page"); + req.set_header("host", "test.example.com"); + req.set_header("x-forwarded-proto", "http"); + + let request_scheme = req + .get_header("x-forwarded-proto") + .and_then(|h| h.to_str().ok()) + .unwrap_or("https"); + + assert_eq!(request_scheme, "http"); + } + + #[test] + fn test_content_type_detection() { + // Test which content types should be processed + let test_cases = vec![ + ("text/html", true), + ("text/html; charset=utf-8", true), + ("text/css", true), + ("text/javascript", true), + ("application/javascript", true), + ("application/json", true), + ("application/json; charset=utf-8", true), + ("image/jpeg", false), + ("image/png", false), + ("application/pdf", false), + ("video/mp4", false), + ("application/octet-stream", 
false), + ]; + + for (content_type, should_process) in test_cases { + let result = content_type.contains("text/html") + || content_type.contains("text/css") + || content_type.contains("text/javascript") + || content_type.contains("application/javascript") + || content_type.contains("application/json"); + + assert_eq!( + result, should_process, + "Content-Type '{}' should_process: expected {}, got {}", + content_type, should_process, result + ); + } + } + + #[test] + fn test_handle_main_page_gdpr_consent() { + let settings = create_test_settings(); + let req = Request::new(Method::GET, "https://example.com/"); + + // Without GDPR consent, tracking should be disabled + let response = handle_main_page(&settings, req).unwrap(); + assert_eq!(response.get_status(), StatusCode::OK); + // Note: Would need to verify response body contains disabled tracking + } + + #[test] + fn test_publisher_origin_host_extraction() { + let settings = create_test_settings(); + let origin_host = settings.publisher.origin_host(); + assert_eq!(origin_host, "origin.test-publisher.com"); + + // Test with port + let mut settings_with_port = create_test_settings(); + settings_with_port.publisher.origin_url = "origin.test-publisher.com:8080".to_string(); + assert_eq!( + settings_with_port.publisher.origin_host(), + "origin.test-publisher.com:8080" + ); + } + + #[test] + fn test_invalid_utf8_handling() { + // Test that invalid UTF-8 bytes are handled gracefully + let invalid_utf8_bytes = vec![0xFF, 0xFE, 0xFD]; // Invalid UTF-8 sequence + + // Verify these bytes cannot be converted to a valid UTF-8 string + assert!(String::from_utf8(invalid_utf8_bytes.clone()).is_err()); + + // In the actual function, invalid UTF-8 would be passed through unchanged + // This test verifies our approach is sound + } + + #[test] + fn test_utf8_conversion_edge_cases() { + // Test various UTF-8 edge cases + let test_cases = vec![ + // Valid UTF-8 with special characters + (vec![0xE2, 0x98, 0x83], true), // β˜ƒ (snowman) 
+ (vec![0xF0, 0x9F, 0x98, 0x80], true), // πŸ˜€ (emoji) + // Invalid UTF-8 sequences + (vec![0xFF, 0xFE], false), // Invalid start byte + (vec![0xC0, 0x80], false), // Overlong encoding + (vec![0xED, 0xA0, 0x80], false), // Surrogate half + ]; + + for (bytes, should_be_valid) in test_cases { + let result = String::from_utf8(bytes.clone()); + assert_eq!( + result.is_ok(), + should_be_valid, + "UTF-8 validation failed for bytes: {:?}", + bytes + ); + } + } + + #[test] + fn test_streaming_compressed_content() { + use flate2::write::GzEncoder; + use flate2::Compression; + use std::io::Write; + + // Create some HTML content with origin URLs + let original_content = r#" + + + Link + "#; + + // Compress the content + let mut compressed = Vec::new(); + { + let mut encoder = GzEncoder::new(&mut compressed, Compression::default()); + encoder.write_all(original_content.as_bytes()).unwrap(); + encoder.finish().unwrap(); + } + + // Create a Body from compressed data + let body = Body::from(compressed); + + // Process the compressed body + let result = process_response_streaming( + body, + "gzip", + "origin.example.com", + "https://origin.example.com", + "test.example.com", + "https", + ); + + assert!(result.is_ok()); + let processed_body = result.unwrap(); + + // The body should still be compressed + // In a real test, we'd decompress and verify the content + // For now, just check that we got a body back + let bytes = processed_body.into_bytes(); + assert!(!bytes.is_empty()); + + // Decompress to verify content was transformed + use flate2::read::GzDecoder; + use std::io::Read; + let mut decoder = GzDecoder::new(&bytes[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + // Verify URLs were replaced + assert!(decompressed.contains("https://test.example.com/style.css")); + assert!(decompressed.contains("https://test.example.com/app.js")); + assert!(decompressed.contains("https://test.example.com/page")); + 
assert!(!decompressed.contains("origin.example.com")); + } + + #[test] + fn test_streaming_brotli_content() { + use brotli::enc::writer::CompressorWriter; + use brotli::enc::BrotliEncoderParams; + use std::io::Write; + + // Create some HTML content with origin URLs + let original_content = r#" + + + Link + "#; + + // Compress the content with Brotli + let mut compressed = Vec::new(); + { + let params = BrotliEncoderParams { + quality: 4, + lgwin: 22, + ..Default::default() + }; + let mut encoder = CompressorWriter::with_params(&mut compressed, 4096, ¶ms); + encoder.write_all(original_content.as_bytes()).unwrap(); + // encoder is dropped here, which finishes the compression + } + + // Create a Body from compressed data + let body = Body::from(compressed); + + // Process the compressed body + let result = process_response_streaming( + body, + "br", + "origin.example.com", + "https://origin.example.com", + "test.example.com", + "https", + ); + + assert!(result.is_ok()); + let processed_body = result.unwrap(); + + // The body should still be compressed + // In a real test, we'd decompress and verify the content + // For now, just check that we got a body back + let bytes = processed_body.into_bytes(); + assert!(!bytes.is_empty()); + + // Decompress to verify content was transformed + use brotli::Decompressor; + use std::io::Read; + let mut decoder = Decompressor::new(&bytes[..], 4096); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + // Verify URLs were replaced + assert!(decompressed.contains("https://test.example.com/style.css")); + assert!(decompressed.contains("https://test.example.com/app.js")); + assert!(decompressed.contains("https://test.example.com/page")); + assert!(!decompressed.contains("origin.example.com")); + } + + #[test] + fn test_content_encoding_detection() { + // Test that we properly handle responses with various content encodings + let test_encodings = vec!["gzip", "deflate", "br", "identity", ""]; 
+ + for encoding in test_encodings { + let mut req = Request::new(Method::GET, "https://test.example.com/page"); + req.set_header("accept-encoding", "gzip, deflate, br"); + + if !encoding.is_empty() { + req.set_header("content-encoding", encoding); + } + + let content_encoding = req + .get_header("content-encoding") + .map(|h| h.to_str().unwrap_or_default()) + .unwrap_or_default(); + + assert_eq!(content_encoding, encoding); + } + } +} diff --git a/crates/common/src/settings.rs b/crates/common/src/settings.rs index 8f65f78..1cbbbb2 100644 --- a/crates/common/src/settings.rs +++ b/crates/common/src/settings.rs @@ -3,6 +3,7 @@ use core::str; use config::{Config, Environment, File, FileFormat}; use error_stack::{Report, ResultExt}; use serde::{Deserialize, Serialize}; +use url::Url; use crate::error::TrustedServerError; @@ -19,9 +20,38 @@ pub struct AdServer { pub struct Publisher { pub domain: String, pub cookie_domain: String, + pub origin_backend: String, pub origin_url: String, } +impl Publisher { + /// Extracts the host (including port if present) from the origin_url. 
+ /// + /// # Examples + /// + /// ``` + /// # use trusted_server_common::settings::Publisher; + /// let publisher = Publisher { + /// domain: "example.com".to_string(), + /// cookie_domain: ".example.com".to_string(), + /// origin_backend: "publisher_origin".to_string(), + /// origin_url: "https://origin.example.com:8080".to_string(), + /// }; + /// assert_eq!(publisher.origin_host(), "origin.example.com:8080"); + /// ``` + #[allow(dead_code)] + pub fn origin_host(&self) -> String { + Url::parse(&self.origin_url) + .ok() + .and_then(|url| { + url.host_str().map(|host| match url.port() { + Some(port) => format!("{}:{}", host, port), + None => host.to_string(), + }) + }) + .unwrap_or_else(|| self.origin_url.clone()) + } +} + #[derive(Debug, Default, Deserialize, Serialize)] pub struct Prebid { pub server_url: String, @@ -67,6 +97,11 @@ pub struct Partners { pub prebid: Option, } +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct Experimental { + pub enable_edge_pub: bool, +} + #[derive(Debug, Default, Deserialize, Serialize)] pub struct Settings { pub ad_server: AdServer, @@ -75,6 +110,7 @@ pub struct Settings { pub gam: Gam, pub synthetic: Synthetic, pub partners: Option, + pub experimental: Option, } #[allow(unused)] @@ -309,4 +345,61 @@ mod tests { }, ); } + + #[test] + fn test_publisher_origin_host() { + // Test with full URL including port + let publisher = Publisher { + domain: "example.com".to_string(), + cookie_domain: ".example.com".to_string(), + origin_backend: "publisher_origin".to_string(), + origin_url: "https://origin.example.com:8080".to_string(), + }; + assert_eq!(publisher.origin_host(), "origin.example.com:8080"); + + // Test with URL without port (default HTTPS port) + let publisher = Publisher { + domain: "example.com".to_string(), + cookie_domain: ".example.com".to_string(), + origin_backend: "publisher_origin".to_string(), + origin_url: "https://origin.example.com".to_string(), + }; + assert_eq!(publisher.origin_host(), "origin.example.com"); + + // Test with HTTP URL with explicit 
port + let publisher = Publisher { + domain: "example.com".to_string(), + cookie_domain: ".example.com".to_string(), + origin_backend: "publisher_origin".to_string(), + origin_url: "http://localhost:9090".to_string(), + }; + assert_eq!(publisher.origin_host(), "localhost:9090"); + + // Test with URL without protocol (fallback to original) + let publisher = Publisher { + domain: "example.com".to_string(), + cookie_domain: ".example.com".to_string(), + origin_backend: "publisher_origin".to_string(), + origin_url: "localhost:9090".to_string(), + }; + assert_eq!(publisher.origin_host(), "localhost:9090"); + + // Test with IPv4 address + let publisher = Publisher { + domain: "example.com".to_string(), + cookie_domain: ".example.com".to_string(), + origin_backend: "publisher_origin".to_string(), + origin_url: "http://192.168.1.1:8080".to_string(), + }; + assert_eq!(publisher.origin_host(), "192.168.1.1:8080"); + + // Test with IPv6 address + let publisher = Publisher { + domain: "example.com".to_string(), + cookie_domain: ".example.com".to_string(), + origin_backend: "publisher_origin".to_string(), + origin_url: "http://[::1]:8080".to_string(), + }; + assert_eq!(publisher.origin_host(), "[::1]:8080"); + } } diff --git a/crates/common/src/streaming_replacer.rs b/crates/common/src/streaming_replacer.rs new file mode 100644 index 0000000..3dc606e --- /dev/null +++ b/crates/common/src/streaming_replacer.rs @@ -0,0 +1,734 @@ +//! Generic streaming replacer for processing large content. +//! +//! This module provides functionality for replacing patterns in content +//! in streaming fashion, handling content that may be split across multiple chunks. 
+ +use std::io::{self, Read, Write}; + +/// A replacement pattern configuration +#[derive(Debug, Clone)] +pub struct Replacement { + /// The string to find + pub find: String, + /// The string to replace it with + pub replace_with: String, +} + +/// A generic streaming replacer that processes content in chunks +pub struct StreamingReplacer { + /// List of replacements to apply + replacements: Vec, + // Buffer to handle partial matches at chunk boundaries + overlap_buffer: Vec, + // Maximum pattern length to determine overlap size + max_pattern_length: usize, +} + +impl StreamingReplacer { + /// Creates a new `StreamingReplacer` with the given replacements. + /// + /// # Arguments + /// + /// * `replacements` - List of string replacements to perform + pub fn new(replacements: Vec) -> Self { + // Calculate the maximum pattern length we need to buffer + let max_pattern_length = replacements.iter().map(|r| r.find.len()).max().unwrap_or(0); + + Self { + replacements, + overlap_buffer: Vec::with_capacity(max_pattern_length), + max_pattern_length, + } + } + + /// Creates a new `StreamingReplacer` with a single replacement. + /// + /// # Arguments + /// + /// * `find` - The string to find + /// * `replace_with` - The string to replace it with + pub fn new_single(find: &str, replace_with: &str) -> Self { + Self::new(vec![Replacement { + find: find.to_string(), + replace_with: replace_with.to_string(), + }]) + } + + /// Process a chunk of data and return the processed output + pub fn process_chunk(&mut self, chunk: &[u8], is_last_chunk: bool) -> Vec { + // Combine overlap buffer with new chunk + let mut combined = self.overlap_buffer.clone(); + combined.extend_from_slice(chunk); + + if combined.is_empty() { + return Vec::new(); + } + + // Determine how much content to process + let process_end_bytes = if is_last_chunk { + combined.len() + } else { + // To avoid splitting patterns, we need to be careful about where we cut. 
+ // We want to keep at least (max_pattern_length - 1) bytes for overlap. + if combined.len() <= self.max_pattern_length { + // Not enough data to process safely + 0 + } else { + // Start with a safe boundary + let mut boundary = combined.len().saturating_sub(self.max_pattern_length - 1); + + // Check if we might be splitting a pattern at this boundary + // by looking for pattern starts near the boundary + let check_start = boundary.saturating_sub(self.max_pattern_length); + let check_end = (boundary + self.max_pattern_length).min(combined.len()); + + if let Ok(check_str) = std::str::from_utf8(&combined[check_start..check_end]) { + // Look for any pattern that would be split by our boundary + for replacement in &self.replacements { + if let Some(pos) = check_str.find(&replacement.find) { + let pattern_start = check_start + pos; + let pattern_end = pattern_start + replacement.find.len(); + + // If the pattern crosses our boundary, adjust the boundary + if pattern_start < boundary && pattern_end > boundary { + boundary = pattern_start; + break; + } + } + } + } + + boundary + } + }; + + if process_end_bytes == 0 { + // Not enough data to process yet + self.overlap_buffer = combined; + return Vec::new(); + } + + // Find a valid UTF-8 boundary at or before process_end_bytes + let mut adjusted_end_bytes = process_end_bytes; + while adjusted_end_bytes > 0 { + // Check if this is a valid UTF-8 boundary + if let Ok(s) = std::str::from_utf8(&combined[..adjusted_end_bytes]) { + // Valid UTF-8 up to this point, process it + let mut processed = s.to_string(); + + // Apply all replacements + for replacement in &self.replacements { + processed = processed.replace(&replacement.find, &replacement.replace_with); + } + + // Save the overlap for the next chunk + if !is_last_chunk { + self.overlap_buffer = combined[adjusted_end_bytes..].to_vec(); + } else { + self.overlap_buffer.clear(); + } + + return processed.into_bytes(); + } + adjusted_end_bytes -= 1; + } + + // This should never 
happen, but handle it gracefully + self.overlap_buffer = combined; + Vec::new() + } + + /// Reset the internal buffer (useful when reusing the replacer) + pub fn reset(&mut self) { + self.overlap_buffer.clear(); + } +} + +/// Process a stream through a StreamingReplacer +/// +/// This function reads from a source, processes chunks through the replacer, +/// and writes the results to the output. +/// +/// # Arguments +/// +/// * `reader` - The input stream to read from +/// * `writer` - The output stream to write to +/// * `replacer` - The streaming replacer to use for processing +/// * `chunk_size` - The size of chunks to read at a time +/// +/// # Returns +/// +/// Returns `Ok(())` on success, or an `io::Error` if reading/writing fails. +pub fn stream_process( + mut reader: R, + writer: &mut W, + replacer: &mut StreamingReplacer, + chunk_size: usize, +) -> io::Result<()> { + let mut buffer = vec![0u8; chunk_size]; + + loop { + match reader.read(&mut buffer)? { + 0 => { + // End of stream - process any remaining data + let final_chunk = replacer.process_chunk(&[], true); + if !final_chunk.is_empty() { + writer.write_all(&final_chunk)?; + } + break; + } + n => { + // Process this chunk + let processed = replacer.process_chunk(&buffer[..n], false); + if !processed.is_empty() { + writer.write_all(&processed)?; + } + } + } + } + + writer.flush()?; + Ok(()) +} + +/// Helper function to create a StreamingReplacer for URL replacements +pub fn create_url_replacer( + origin_host: &str, + origin_url: &str, + request_host: &str, + request_scheme: &str, +) -> StreamingReplacer { + let request_url = format!("{}://{}", request_scheme, request_host); + + let mut replacements = vec![ + // Replace full URLs first (more specific) + Replacement { + find: origin_url.to_string(), + replace_with: request_url.clone(), + }, + ]; + + // Also handle HTTP variant if origin is HTTPS + if origin_url.starts_with("https://") { + let http_origin_url = origin_url.replace("https://", "http://"); + 
replacements.push(Replacement { + find: http_origin_url, + replace_with: request_url.clone(), + }); + } + + // Replace protocol-relative URLs + replacements.push(Replacement { + find: format!("//{}", origin_host), + replace_with: format!("//{}", request_host), + }); + + // Replace host in various contexts + replacements.push(Replacement { + find: origin_host.to_string(), + replace_with: request_host.to_string(), + }); + + StreamingReplacer::new(replacements) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_streaming_replacer_basic() { + let mut replacer = + StreamingReplacer::new_single("https://origin.example.com", "https://test.example.com"); + + let input = b"Visit https://origin.example.com for more info"; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + assert_eq!(result, "Visit https://test.example.com for more info"); + } + + #[test] + fn test_multiple_replacements() { + let replacements = vec![ + Replacement { + find: "foo".to_string(), + replace_with: "bar".to_string(), + }, + Replacement { + find: "hello".to_string(), + replace_with: "hi".to_string(), + }, + ]; + + let mut replacer = StreamingReplacer::new(replacements); + + let input = b"hello world, foo is foo"; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + assert_eq!(result, "hi world, bar is bar"); + } + + #[test] + fn test_streaming_replacer_chunks() { + let mut replacer = + StreamingReplacer::new_single("https://origin.example.com", "https://test.example.com"); + + // Test that patterns split across chunks are handled correctly + let chunk1 = b"Visit https://origin.exam"; + let chunk2 = b"ple.com for more info"; + + let processed1 = replacer.process_chunk(chunk1, false); + let processed2 = replacer.process_chunk(chunk2, true); + + let result = String::from_utf8([processed1, processed2].concat()).unwrap(); + assert_eq!(result, "Visit 
https://test.example.com for more info"); + } + + #[test] + fn test_streaming_replacer_multiple_patterns() { + let replacements = vec![ + Replacement { + find: "https://origin.example.com".to_string(), + replace_with: "https://test.example.com".to_string(), + }, + Replacement { + find: "//origin.example.com".to_string(), + replace_with: "//test.example.com".to_string(), + }, + ]; + + let mut replacer = StreamingReplacer::new(replacements); + + let input = + b"Link and //origin.example.com/resource"; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + assert!(result.contains("https://test.example.com")); + assert!(result.contains("//test.example.com/resource")); + } + + #[test] + fn test_streaming_replacer_edge_cases() { + let mut replacer = + StreamingReplacer::new_single("https://origin.example.com", "https://test.example.com"); + + // Empty chunk + let processed = replacer.process_chunk(b"", true); + assert!(processed.is_empty()); + + // Very small chunks + let chunks = [ + b"h".as_ref(), + b"t".as_ref(), + b"t".as_ref(), + b"p".as_ref(), + b"s".as_ref(), + b":".as_ref(), + b"/".as_ref(), + b"/".as_ref(), + b"origin.example.com".as_ref(), + ]; + + let mut result = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let is_last = i == chunks.len() - 1; + let processed = replacer.process_chunk(chunk, is_last); + result.extend(processed); + } + + let result_str = String::from_utf8(result).unwrap(); + assert_eq!(result_str, "https://test.example.com"); + } + + #[test] + fn test_url_replacer_comprehensive() { + let mut replacer = create_url_replacer( + "origin.example.com", + "https://origin.example.com", + "test.example.com", + "https", + ); + + // Test comprehensive URL replacement scenarios + let content = r#" + + Link + + + + + + + + + {"api": "https://origin.example.com/api", "host": "origin.example.com"} + "#; + + let processed = replacer.process_chunk(content.as_bytes(), true); + let result 
= String::from_utf8(processed).unwrap(); + + // Verify all patterns were replaced + assert!(result.contains("https://test.example.com/page")); + assert!(result.contains("https://test.example.com/image.jpg")); + assert!(result.contains("//test.example.com/script.js")); + assert!(result.contains(r#""api": "https://test.example.com/api""#)); + assert!(result.contains(r#""host": "test.example.com""#)); + + // Ensure no origin URLs remain + assert!(!result.contains("origin.example.com")); + } + + #[test] + fn test_url_replacer_with_port() { + let mut replacer = create_url_replacer( + "origin.example.com:8080", + "https://origin.example.com:8080", + "test.example.com:9090", + "https", + ); + + let content = + b"Visit https://origin.example.com:8080/api or //origin.example.com:8080/resource"; + let processed = replacer.process_chunk(content, true); + let result = String::from_utf8(processed).unwrap(); + + assert_eq!( + result, + "Visit https://test.example.com:9090/api or //test.example.com:9090/resource" + ); + } + + #[test] + fn test_url_replacer_mixed_protocols() { + let mut replacer = create_url_replacer( + "origin.example.com", + "https://origin.example.com", + "test.example.com", + "http", + ); + + let content = r#" + HTTPS Link + HTTP Link + + "#; + + let processed = replacer.process_chunk(content.as_bytes(), true); + let result = String::from_utf8(processed).unwrap(); + + // When request is HTTP, all URLs should be replaced with HTTP + assert!(result.contains("http://test.example.com")); + assert!(!result.contains("https://test.example.com")); + assert!(result.contains("//test.example.com/script.js")); + } + + #[test] + fn test_process_chunk_utf8_boundary() { + let mut replacer = + create_url_replacer("origin.com", "https://origin.com", "test.com", "https"); + + // Create content with multi-byte UTF-8 characters that could cause boundary issues + let content = "https://origin.com/test 思怙ᕏ桋试 https://origin.com/more".as_bytes(); + + // Process in small chunks to 
force potential boundary issues + let chunk_size = 20; + let mut result = Vec::new(); + + for (i, chunk) in content.chunks(chunk_size).enumerate() { + let is_last = i == content.chunks(chunk_size).count() - 1; + result.extend(replacer.process_chunk(chunk, is_last)); + } + + let result_str = String::from_utf8(result).unwrap(); + assert!(result_str.contains("https://test.com/test")); + assert!(result_str.contains("https://test.com/more")); + assert!(result_str.contains("思怙ᕏ桋试")); + } + + #[test] + fn test_process_chunk_boundary_in_multibyte_char() { + let mut replacer = + create_url_replacer("example.com", "https://example.com", "new.com", "https"); + + // Create a scenario where chunk boundary falls in the middle of a multi-byte character + let content = "https://example.com/fΓΈr/bΓ₯r/test".as_bytes(); + + // Split at byte 23, which should be in the middle of 'ΓΈ' (2-byte character) + let chunk1 = &content[..23]; + let chunk2 = &content[23..]; + + let mut result = Vec::new(); + result.extend(replacer.process_chunk(chunk1, false)); + result.extend(replacer.process_chunk(chunk2, true)); + + let result_str = String::from_utf8(result).unwrap(); + assert!(result_str.contains("https://new.com/fΓΈr/bΓ₯r/test")); + } + + #[test] + fn test_process_chunk_emoji_boundary() { + let mut replacer = + create_url_replacer("emoji.com", "https://emoji.com", "test.com", "https"); + + // Test with 4-byte emoji characters + let content = "https://emoji.com/test πŸŽ‰πŸŽŠπŸŽ‹ https://emoji.com/more".as_bytes(); + + // Process the entire content at once to verify it works + let all_at_once = replacer.process_chunk(content, true); + let expected = String::from_utf8(all_at_once).unwrap(); + assert!(expected.contains("https://test.com/test")); + assert!(expected.contains("https://test.com/more")); + } + + #[test] + fn test_process_chunk_large_chunks() { + let mut replacer = + create_url_replacer("example.com", "https://example.com", "test.com", "https"); + + // Test with content that won't 
have URLs split across chunks + let content = + "Visit https://example.com/page1 and then https://example.com/page2 for more info" + .as_bytes(); + + // Use large chunks to avoid splitting URLs + let chunk_size = 50; + let mut result = Vec::new(); + + for (i, chunk) in content.chunks(chunk_size).enumerate() { + let is_last = i == content.chunks(chunk_size).count() - 1; + result.extend(replacer.process_chunk(chunk, is_last)); + } + + let result_str = String::from_utf8(result).unwrap(); + assert!(result_str.contains("https://test.com/page1")); + assert!(result_str.contains("https://test.com/page2")); + } + + #[test] + fn test_process_chunk_utf8_boundary_small_chunks() { + let mut replacer = create_url_replacer("test.com", "https://test.com", "new.com", "https"); + + // Test with multi-byte characters and very small chunks to stress UTF-8 boundaries + let content = "Some text 思怙ᕏ桋试 more text with πŸŽ‰ emoji".as_bytes(); + + // Use very small chunks to force UTF-8 boundary handling + let chunk_size = 8; + let mut result = Vec::new(); + let chunks: Vec<_> = content.chunks(chunk_size).collect(); + + for (i, chunk) in chunks.iter().enumerate() { + let is_last = i == chunks.len() - 1; + result.extend(replacer.process_chunk(chunk, is_last)); + } + + let result_str = String::from_utf8(result).unwrap(); + // Just verify the content is preserved correctly + assert!(result_str.contains("思怙ᕏ桋试")); + assert!(result_str.contains("πŸŽ‰")); + } + + #[test] + fn test_generic_replacements() { + // Test replacing arbitrary strings + let replacements = vec![ + Replacement { + find: "color".to_string(), + replace_with: "colour".to_string(), + }, + Replacement { + find: "gray".to_string(), + replace_with: "grey".to_string(), + }, + ]; + + let mut replacer = StreamingReplacer::new(replacements); + + let input = b"The color is gray, not light gray."; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + assert_eq!(result, "The 
colour is grey, not light grey."); + } + + #[test] + fn test_pattern_priority() { + // Test that longer patterns are replaced first (order matters) + let replacements = vec![ + Replacement { + find: "hello world".to_string(), + replace_with: "greetings universe".to_string(), + }, + Replacement { + find: "hello".to_string(), + replace_with: "hi".to_string(), + }, + ]; + + let mut replacer = StreamingReplacer::new(replacements); + + let input = b"Say hello world and hello there!"; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + // Note: Since we apply replacements in order, "hello world" gets replaced first + assert_eq!(result, "Say greetings universe and hi there!"); + } + + #[test] + fn test_overlapping_patterns() { + // Test handling of overlapping patterns + let replacements = vec![ + Replacement { + find: "abc".to_string(), + replace_with: "xyz".to_string(), + }, + Replacement { + find: "bcd".to_string(), + replace_with: "123".to_string(), + }, + ]; + + let mut replacer = StreamingReplacer::new(replacements); + + let input = b"abcdef"; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + // "abc" gets replaced first, so "bcd" is no longer found + assert_eq!(result, "xyzdef"); + } + + #[test] + fn test_empty_replacement() { + // Test removing strings (replacing with empty string) + let mut replacer = StreamingReplacer::new_single("REMOVE_ME", ""); + + let input = b"Keep this REMOVE_ME but not this"; + let processed = replacer.process_chunk(input, true); + let result = String::from_utf8(processed).unwrap(); + + assert_eq!(result, "Keep this but not this"); + } + + #[test] + fn test_case_sensitive_replacement() { + // Test that replacements are case-sensitive + let mut replacer = StreamingReplacer::new_single("Hello", "Hi"); + + let input = b"Hello world, hello there, HELLO!"; + let processed = replacer.process_chunk(input, true); + let 
result = String::from_utf8(processed).unwrap(); + + assert_eq!(result, "Hi world, hello there, HELLO!"); + } + + #[test] + fn test_special_characters_in_pattern() { + // Test replacing patterns with special regex characters + let replacements = vec![ + Replacement { + find: "cost: $10.99".to_string(), + replace_with: "price: €9.99".to_string(), + }, + Replacement { + find: "[TAG]".to_string(), + replace_with: "