diff --git a/Cargo.toml b/Cargo.toml index d7056c6..49e6187 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ rust-version = "1.78.0" reqwest = { version = "0.11.3", default-features = false, features = [ "rustls-tls", "blocking", + "json", + "gzip", ] } serde = { version = "1.0.125", features = ["derive"] } chrono = { version = "0.4.19", features = ["serde"] } @@ -19,12 +21,19 @@ serde_json = "1.0.64" semver = "1.0.24" derive_builder = "0.20.2" uuid = { version = "1.13.2", features = ["serde", "v7"] } +sha1 = "0.10" +regex = "1.10" +tokio = { version = "1", features = ["rt", "sync", "time", "macros"], optional = true } [dev-dependencies] dotenv = "0.15.0" ctor = "0.1.26" +tokio = { version = "1", features = ["full"] } +httpmock = "0.7" +serde_json = "1.0" +futures = "0.3" [features] default = ["async-client"] e2e-test = [] -async-client = [] +async-client = ["tokio"] diff --git a/README.md b/README.md index 3fd7966..7f1c13c 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,121 @@ posthog-rs = "0.3.7" ```rust let client = posthog_rs::client(env!("POSTHOG_API_KEY")); +// Capture events let mut event = posthog_rs::Event::new("test", "1234"); event.insert_prop("key1", "value1").unwrap(); event.insert_prop("key2", vec!["a", "b"]).unwrap(); client.capture(event).unwrap(); + +// Check feature flags +let is_enabled = client.is_feature_enabled( + "new-feature".to_string(), + "user-123".to_string(), + None, + None, + None, +).unwrap(); + +if is_enabled { + println!("Feature is enabled!"); +} +``` + +## Feature Flags + +The SDK now supports PostHog feature flags, allowing you to control feature rollout and run A/B tests. + +### Basic Usage + +```rust +use posthog_rs::{client, ClientOptions, FlagValue}; +use std::collections::HashMap; +use serde_json::json; + +let client = client(ClientOptions::from("your-api-key")); + +// Check if a feature is enabled +let is_enabled = client.is_feature_enabled( + "feature-key".to_string(), + "user-id".to_string(), + None, None, None +).unwrap(); + +// Get feature flag value (boolean or variant) +match client.get_feature_flag( + "feature-key".to_string(), + "user-id".to_string(), + None, None, None +).unwrap() { + Some(FlagValue::Boolean(enabled)) => println!("Flag is: {}", enabled), + Some(FlagValue::String(variant)) => println!("Variant: {}", variant), + None => println!("Flag is disabled"), +} +``` + +### With Properties + +```rust +// Include person properties for flag evaluation +let mut person_props = HashMap::new(); +person_props.insert("plan".to_string(), json!("enterprise")); +person_props.insert("country".to_string(), json!("US")); + +let flag = client.get_feature_flag( + "premium-feature".to_string(), + "user-id".to_string(), + None, + Some(person_props), + None +).unwrap(); +``` + +### With Groups (B2B) + +```rust +// For B2B apps with group-based flags +let mut groups = HashMap::new(); +groups.insert("company".to_string(), "company-123".to_string()); + +let mut group_props = HashMap::new(); +let mut company_props = HashMap::new(); +company_props.insert("size".to_string(), json!(500)); +group_props.insert("company".to_string(), company_props); + +let flag = client.get_feature_flag( + "b2b-feature".to_string(), + "user-id".to_string(), + Some(groups), + None, + Some(group_props) +).unwrap(); +``` + +### Get All Flags + +```rust +// Get all feature flags for a user +let response = client.get_feature_flags( + "user-id".to_string(), + None, None, None +).unwrap(); + +for (key, value) in response.feature_flags { + println!("Flag {}: {:?}", key, value); +} +``` + +### Feature Flag Payloads + +```rust +// Get additional data associated with a feature flag +let payload = client.get_feature_flag_payload( + "onboarding-flow".to_string(), + "user-id".to_string() +).unwrap(); + +if let Some(data) = payload { + println!("Payload: {}", data); +} ``` diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..31ec16d --- /dev/null +++ b/examples/README.md @@ -0,0 +1,93 @@ +# PostHog Rust SDK Examples + +This directory contains example applications demonstrating how to use the PostHog Rust SDK, particularly the feature flags functionality. + +## Running the Examples + +### 1. Feature Flags with Mock Data (No API Key Required) +This example demonstrates feature flag evaluation using local mock data: + +```bash +cargo run --example feature_flags_with_mock +``` + +This example shows: +- Simple percentage rollouts +- Property-based targeting (country, plan, age) +- Multivariate experiments with variants +- Local evaluation without API calls + +### 2. Feature Flags Demo Application +A complete e-commerce demo with interactive testing: + +```bash +# With a real PostHog API key +POSTHOG_API_KEY=your_api_key cargo run --example feature_flags_demo --all-features + +# Without API key (uses local fallbacks) +cargo run --example feature_flags_demo --all-features +``` + +Features demonstrated: +- New checkout flow (premium/enterprise plans) +- AI recommendations (multivariate test) +- Pricing experiments (based on lifetime value) +- Holiday themes (geographic targeting) +- Interactive testing mode + +### 3. Basic Feature Flags Example +Simple examples of all feature flag operations: + +```bash +POSTHOG_API_KEY=your_api_key cargo run --example feature_flags +``` + +Shows: +- Checking if flags are enabled +- Getting flag values and variants +- Using person and group properties +- Getting feature flag payloads + +## Testing Without a PostHog Account + +The `feature_flags_with_mock` example is perfect for testing the SDK without needing a PostHog account. It demonstrates: + +1. **Percentage Rollouts**: Flags enabled for a percentage of users +2. **Property Matching**: Target users based on properties like country, plan, age +3. **Multivariate Testing**: Split users into different variants +4. **Complex Conditions**: Combine multiple conditions with AND logic + +## Key Concepts + +### Feature Flag Types + +1. **Boolean Flags**: Simple on/off toggles + ```rust + FlagValue::Boolean(true) // enabled + FlagValue::Boolean(false) // disabled + ``` + +2. **Multivariate Flags**: Multiple variants for A/B/n testing + ```rust + FlagValue::String("control") + FlagValue::String("variant-a") + FlagValue::String("variant-b") + ``` + +### Evaluation Methods + +1. **Remote Evaluation**: Calls PostHog API for the latest flag values +2. **Local Evaluation**: Uses cached flag definitions for offline evaluation + +### Properties + +- **Person Properties**: User attributes (country, plan, age, etc.) +- **Group Properties**: Organization/team attributes for B2B apps + +## Common Use Cases + +1. **Feature Rollouts**: Gradually release features to users +2. **A/B Testing**: Test different variants to measure impact +3. **User Targeting**: Enable features for specific user segments +4. **Kill Switches**: Quickly disable problematic features +5. **Beta Programs**: Give early access to beta users \ No newline at end of file diff --git a/examples/advanced_config.rs b/examples/advanced_config.rs new file mode 100644 index 0000000..75df8db --- /dev/null +++ b/examples/advanced_config.rs @@ -0,0 +1,64 @@ +/// SDK Configuration Examples +/// +/// Shows different ways to configure the PostHog Rust SDK for various use cases. +/// +/// Run with: cargo run --example advanced_config --features async-client +use posthog_rs::{ClientOptionsBuilder, EU_INGESTION_ENDPOINT}; + +#[cfg(feature = "async-client")] +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("=== PostHog SDK Configuration Examples ===\n"); + + // 1. SIMPLEST: Just an API key (uses US endpoint by default) + println!("1. Basic client (US region):"); + let _basic = posthog_rs::client("phc_test_api_key").await; + println!(" → Created with default settings\n"); + + // 2. REGIONAL: EU data residency + println!("2. EU region client:"); + let _eu = posthog_rs::client(("phc_test_api_key", EU_INGESTION_ENDPOINT)).await; + println!(" → Data stays in EU (GDPR compliant)\n"); + + // 3. SELF-HOSTED: Your own PostHog instance + println!("3. Self-hosted instance:"); + let _custom = posthog_rs::client(("phc_test_api_key", "https://analytics.mycompany.com")).await; + println!(" → Uses your private PostHog deployment\n"); + + // 4. PRODUCTION: Common production settings + println!("4. Production configuration:"); + let production_config = ClientOptionsBuilder::default() + .api_key("phc_production_key".to_string()) + .host("https://eu.posthog.com") // Auto-detects and uses EU ingestion + .gzip(true) // Compress requests + .request_timeout_seconds(30) // 30s timeout + .build()?; + + let _prod = posthog_rs::client(production_config).await; + println!(" → Optimized for production workloads\n"); + + // 5. HIGH-PERFORMANCE: Local flag evaluation + println!("5. High-performance with local evaluation:"); + let performance_config = ClientOptionsBuilder::default() + .api_key("phc_project_key".to_string()) + .personal_api_key("phx_personal_key") // Required for local eval + .enable_local_evaluation(true) // Cache flags locally + .poll_interval_seconds(30) // Update cache every 30s + .feature_flags_request_timeout_seconds(3) + .build()?; + + let _perf = posthog_rs::client(performance_config).await; + println!(" → Evaluates flags locally (100x faster)\n"); + + println!("✅ Configuration examples complete!"); + println!("\nTip: Check out 'feature_flags' example for flag usage"); + println!(" and 'local_evaluation' for performance optimization."); + + Ok(()) +} + +#[cfg(not(feature = "async-client"))] +fn main() { + println!("This example requires the async-client feature."); + println!("Run with: cargo run --example advanced_config --features async-client"); +} diff --git a/examples/feature_flags.rs b/examples/feature_flags.rs new file mode 100644 index 0000000..da3c731 --- /dev/null +++ b/examples/feature_flags.rs @@ -0,0 +1,177 @@ +/// Feature Flags Example +/// +/// Shows all feature flag patterns: boolean flags, A/B tests, payloads, and targeting. +/// +/// Run with real API: +/// export POSTHOG_API_TOKEN=phc_your_key +/// cargo run --example feature_flags --features async-client +use posthog_rs::FlagValue; +use serde_json::json; +use std::collections::HashMap; + +#[cfg(feature = "async-client")] +#[tokio::main] +async fn main() { + // Try to get API key from environment, or use demo mode + let api_key = std::env::var("POSTHOG_API_TOKEN").unwrap_or_else(|_| { + println!("No POSTHOG_API_TOKEN found. Running in demo mode with mock data.\n"); + "demo_api_key".to_string() + }); + + let is_demo = api_key == "demo_api_key"; + + // Create client + let client = if is_demo { + create_demo_client().await + } else { + posthog_rs::client(api_key.as_str()).await + }; + + // Example 1: Simple boolean flag check + println!("=== Example 1: Boolean Feature Flag ==="); + let user_id = "user-123"; + + match client + .is_feature_enabled( + "new-dashboard".to_string(), + user_id.to_string(), + None, + None, + None, + ) + .await + { + Ok(enabled) => { + if enabled { + println!("✅ New dashboard is enabled for {}", user_id); + } else { + println!("❌ New dashboard is disabled for {}", user_id); + } + } + Err(e) => println!("Error checking flag: {}", e), + } + + // Example 2: Multivariate flag (A/B testing) + println!("\n=== Example 2: A/B Test Variant ==="); + + match client + .get_feature_flag( + "checkout-flow".to_string(), + user_id.to_string(), + None, + None, + None, + ) + .await + { + Ok(Some(FlagValue::String(variant))) => { + println!("User {} gets checkout variant: {}", user_id, variant); + match variant.as_str() { + "control" => println!(" → Show original checkout flow"), + "variant-a" => println!(" → Show streamlined checkout"), + "variant-b" => println!(" → Show one-click checkout"), + _ => println!(" → Unknown variant"), + } + } + Ok(Some(FlagValue::Boolean(enabled))) => { + println!("Checkout flow flag is a boolean: {}", enabled); + } + Ok(None) => { + println!("Checkout flow flag not found or not evaluated"); + } + Err(e) => println!("Error getting flag: {}", e), + } + + // Example 3: Using person properties for targeting + println!("\n=== Example 3: Property-based Targeting ==="); + + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("premium")); + properties.insert("country".to_string(), json!("US")); + properties.insert("account_age_days".to_string(), json!(45)); + + match client + .get_feature_flag( + "premium-features".to_string(), + user_id.to_string(), + None, + Some(properties.clone()), + None, + ) + .await + { + Ok(Some(FlagValue::Boolean(true))) => { + println!("✅ Premium features enabled (user matches targeting rules)"); + } + Ok(Some(FlagValue::Boolean(false))) => { + println!("❌ Premium features disabled (user doesn't match targeting rules)"); + } + Ok(Some(FlagValue::String(v))) => { + println!("Premium features variant: {}", v); + } + Ok(None) => { + println!("Premium features flag not found"); + } + Err(e) => println!("Error: {}", e), + } + + // Example 4: Getting all flags at once + println!("\n=== Example 4: Batch Flag Evaluation ==="); + + match client + .get_feature_flags(user_id.to_string(), None, Some(properties), None) + .await + { + Ok((flags, payloads)) => { + println!("All flags for {}:", user_id); + for (flag_key, flag_value) in flags { + match flag_value { + FlagValue::Boolean(b) => println!(" {}: {}", flag_key, b), + FlagValue::String(s) => println!(" {}: \"{}\"", flag_key, s), + } + } + + if !payloads.is_empty() { + println!("\nFlag payloads:"); + for (flag_key, payload) in payloads { + println!(" {}: {}", flag_key, payload); + } + } + } + Err(e) => println!("Error getting all flags: {}", e), + } + + // Example 5: Feature flag with payload + println!("\n=== Example 5: Feature Flag Payload ==="); + + match client + .get_feature_flag_payload("onboarding-config".to_string(), user_id.to_string()) + .await + { + Ok(Some(payload)) => { + println!("Onboarding configuration payload:"); + println!("{}", serde_json::to_string_pretty(&payload).unwrap()); + + // Use payload data + if let Some(steps) = payload.get("steps").and_then(|v| v.as_array()) { + println!("\nOnboarding steps: {} steps total", steps.len()); + } + } + Ok(None) => { + println!("No payload for onboarding-config flag"); + } + Err(e) => println!("Error getting payload: {}", e), + } +} + +#[cfg(feature = "async-client")] +async fn create_demo_client() -> posthog_rs::Client { + println!("Note: Running in demo mode. API calls will fail but code structure is shown.\n"); + posthog_rs::client(("demo_key", "https://demo.posthog.com")).await +} + +#[cfg(not(feature = "async-client"))] +fn main() { + println!("This example requires the async-client feature."); + println!("Run with: cargo run --example feature_flags --features async-client"); +} diff --git a/examples/local_evaluation.rs b/examples/local_evaluation.rs new file mode 100644 index 0000000..4a0f11b --- /dev/null +++ b/examples/local_evaluation.rs @@ -0,0 +1,156 @@ +/// Local Evaluation Performance Demo +/// +/// Shows 100-1000x faster flag evaluation by caching definitions locally. +/// +/// Setup: +/// export POSTHOG_API_TOKEN=phc_your_project_key +/// export POSTHOG_PERSONAL_API_TOKEN=phx_your_personal_key +/// cargo run --example local_evaluation --features async-client +/// +/// Get personal key at: https://app.posthog.com/me/settings +use posthog_rs::ClientOptionsBuilder; +use serde_json::json; +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +#[cfg(feature = "async-client")] +#[tokio::main] +async fn main() { + // Get API keys from environment + let api_key = match std::env::var("POSTHOG_API_TOKEN") { + Ok(key) => key, + Err(_) => { + eprintln!("Error: POSTHOG_API_TOKEN environment variable not set"); + eprintln!("Please set it to your PostHog project API token"); + eprintln!("\nExample: export POSTHOG_API_TOKEN=phc_..."); + std::process::exit(1); + } + }; + + let personal_key = match std::env::var("POSTHOG_PERSONAL_API_TOKEN") { + Ok(key) => key, + Err(_) => { + eprintln!("Error: POSTHOG_PERSONAL_API_TOKEN environment variable not set"); + eprintln!("Please set it to your PostHog personal API token"); + eprintln!("\nTo create a personal API key:"); + eprintln!("1. Go to https://app.posthog.com/me/settings"); + eprintln!("2. Click 'Create personal API key'"); + eprintln!("3. Export it: export POSTHOG_PERSONAL_API_TOKEN=phx_..."); + std::process::exit(1); + } + }; + + println!("=== Local Evaluation Performance Demo ===\n"); + + // Create client WITH local evaluation + let local_client = { + let options = ClientOptionsBuilder::default() + .api_key(api_key.clone()) + .personal_api_key(personal_key) + .enable_local_evaluation(true) + .poll_interval_seconds(30) // Poll for updates every 30 seconds + .build() + .unwrap(); + + posthog_rs::client(options).await + }; + + // Create client WITHOUT local evaluation (for comparison) + let api_client = { + let options = ClientOptionsBuilder::default() + .api_key(api_key) + .build() + .unwrap(); + + posthog_rs::client(options).await + }; + + // Give local evaluation time to fetch initial flags + println!("Fetching flag definitions for local evaluation..."); + tokio::time::sleep(Duration::from_secs(2)).await; + + // Test data + let user_id = "perf-test-user"; + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("enterprise")); + properties.insert("country".to_string(), json!("US")); + + // Performance comparison + println!("\n=== Performance Comparison ==="); + + // Test API evaluation speed + println!("\n1. API Evaluation (10 requests):"); + let start = Instant::now(); + for i in 0..10 { + let _ = api_client + .get_feature_flag( + "using-feature-flags".to_string(), + format!("{}-{}", user_id, i), + None, + Some(properties.clone()), + None, + ) + .await; + } + let api_duration = start.elapsed(); + println!( + " Time: {:?} total, {:?} per request", + api_duration, + api_duration / 10 + ); + + // Test local evaluation speed + println!("\n2. Local Evaluation (10 requests):"); + let start = Instant::now(); + for i in 0..10 { + let _ = local_client + .get_feature_flag( + "using-feature-flags".to_string(), + format!("{}-{}", user_id, i), + None, + Some(properties.clone()), + None, + ) + .await; + } + let local_duration = start.elapsed(); + println!( + " Time: {:?} total, {:?} per request", + local_duration, + local_duration / 10 + ); + + // Show speedup + let speedup = api_duration.as_micros() as f64 / local_duration.as_micros().max(1) as f64; + println!("\n📊 Local evaluation is {:.1}x faster!", speedup); + + // Demonstrate batch evaluation + println!("\n=== Batch Evaluation Demo ==="); + + let start = Instant::now(); + match local_client + .get_feature_flags(user_id.to_string(), None, Some(properties), None) + .await + { + Ok((flags, _)) => { + let duration = start.elapsed(); + println!("Evaluated {} flags in {:?}", flags.len(), duration); + + // Show some flag values + println!("\nSample flags:"); + for (key, value) in flags.iter().take(5) { + println!(" {}: {:?}", key, value); + } + } + Err(e) => println!("Error: {}", e), + } + + println!("\n✅ Local evaluation continues polling for updates in the background"); + println!(" Updates will be fetched every 30 seconds automatically"); +} + +#[cfg(not(feature = "async-client"))] +fn main() { + println!("This example requires the async-client feature."); + println!("Run with: cargo run --example local_evaluation --features async-client"); +} diff --git a/src/client/async_client.rs b/src/client/async_client.rs index 56139fb..a3a454e 100644 --- a/src/client/async_client.rs +++ b/src/client/async_client.rs @@ -1,7 +1,12 @@ +use std::collections::HashMap; use std::time::Duration; use reqwest::{header::CONTENT_TYPE, Client as HttpClient}; +use serde_json::json; +use crate::endpoints::{Endpoint, EndpointManager}; +use crate::feature_flags::{match_feature_flag, FeatureFlag, FeatureFlagsResponse, FlagValue}; +use crate::local_evaluation::{AsyncFlagPoller, FlagCache, LocalEvaluationConfig, LocalEvaluator}; use crate::{event::InnerEvent, Error, Event}; use super::ClientOptions; @@ -10,30 +15,85 @@ use super::ClientOptions; pub struct Client { options: ClientOptions, client: HttpClient, + local_evaluator: Option, + _flag_poller: Option, } /// This function constructs a new client using the options provided. pub async fn client>(options: C) -> Client { - let options = options.into(); + let mut options = options.into(); + // Ensure endpoint_manager is properly initialized based on the host + options.endpoint_manager = EndpointManager::new(options.host.clone()); let client = HttpClient::builder() .timeout(Duration::from_secs(options.request_timeout_seconds)) .build() .unwrap(); // Unwrap here is as safe as `HttpClient::new` - Client { options, client } + + let (local_evaluator, flag_poller) = if options.enable_local_evaluation { + if let Some(ref personal_key) = options.personal_api_key { + let cache = FlagCache::new(); + + let config = LocalEvaluationConfig { + personal_api_key: personal_key.clone(), + project_api_key: options.api_key.clone(), + api_host: options.endpoints().api_host(), + poll_interval: Duration::from_secs(options.poll_interval_seconds), + request_timeout: Duration::from_secs(options.request_timeout_seconds), + }; + + let mut poller = AsyncFlagPoller::new(config, cache.clone()); + poller.start().await; + + (Some(LocalEvaluator::new(cache)), Some(poller)) + } else { + eprintln!("[FEATURE FLAGS] Local evaluation enabled but personal_api_key not set"); + (None, None) + } + } else { + (None, None) + }; + + Client { + options, + client, + local_evaluator, + _flag_poller: flag_poller, + } } impl Client { /// Capture the provided event, sending it to PostHog. pub async fn capture(&self, event: Event) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + let inner_event = InnerEvent::new(event, self.options.api_key.clone()); let payload = serde_json::to_string(&inner_event).map_err(|e| Error::Serialization(e.to_string()))?; - self.client - .post(&self.options.api_endpoint) - .header(CONTENT_TYPE, "application/json") - .body(payload) + let mut url = self.options.endpoints().build_url(Endpoint::Capture); + if self.options.disable_geoip { + let separator = if url.contains('?') { "&" } else { "?" }; + url.push_str(&format!("{separator}disable_geoip=1")); + } + + let request = self + .client + .post(&url) + .header(CONTENT_TYPE, "application/json"); + + // Apply gzip compression if enabled + let request = if self.options.gzip { + // Note: reqwest will automatically compress the body when gzip feature is enabled + // and Content-Encoding header is set + request.header("Content-Encoding", "gzip").body(payload) + } else { + request.body(payload) + }; + + request .send() .await .map_err(|e| Error::Connection(e.to_string()))?; @@ -44,6 +104,10 @@ impl Client { /// Capture a collection of events with a single request. This function may be /// more performant than capturing a list of events individually. pub async fn capture_batch(&self, events: Vec) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + let events: Vec<_> = events .into_iter() .map(|event| InnerEvent::new(event, self.options.api_key.clone())) @@ -52,14 +116,212 @@ impl Client { let payload = serde_json::to_string(&events).map_err(|e| Error::Serialization(e.to_string()))?; - self.client - .post(&self.options.api_endpoint) - .header(CONTENT_TYPE, "application/json") - .body(payload) + let mut url = self.options.endpoints().build_url(Endpoint::Capture); + if self.options.disable_geoip { + let separator = if url.contains('?') { "&" } else { "?" }; + url.push_str(&format!("{separator}disable_geoip=1")); + } + + let request = self + .client + .post(&url) + .header(CONTENT_TYPE, "application/json"); + + // Apply gzip compression if enabled + let request = if self.options.gzip { + // Note: reqwest will automatically compress the body when gzip feature is enabled + // and Content-Encoding header is set + request.header("Content-Encoding", "gzip").body(payload) + } else { + request.body(payload) + }; + + request .send() .await .map_err(|e| Error::Connection(e.to_string()))?; Ok(()) } + + /// Get all feature flags for a user + pub async fn get_feature_flags>( + &self, + distinct_id: S, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result< + ( + HashMap, + HashMap, + ), + Error, + > { + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let mut payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id.into(), + }); + + if let Some(groups) = groups { + payload["groups"] = json!(groups); + } + + if let Some(person_properties) = person_properties { + payload["person_properties"] = json!(person_properties); + } + + if let Some(group_properties) = group_properties { + payload["group_properties"] = json!(group_properties); + } + + // Add geoip disable parameter if configured + if self.options.disable_geoip { + payload["disable_geoip"] = json!(true); + } + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .timeout(Duration::from_secs( + self.options.feature_flags_request_timeout_seconds, + )) + .send() + .await + .map_err(|e| Error::Connection(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status(); + let text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::Connection(format!( + "API request failed with status {status}: {text}" + ))); + } + + let flags_response = response.json::().await.map_err(|e| { + Error::Serialization(format!("Failed to parse feature flags response: {e}")) + })?; + + Ok(flags_response.normalize()) + } + + /// Get a specific feature flag value for a user + pub async fn get_feature_flag, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result, Error> { + let key_str = key.into(); + let distinct_id_str = distinct_id.into(); + + // Try local evaluation first if available + if let Some(ref evaluator) = self.local_evaluator { + let props = person_properties.clone().unwrap_or_default(); + match evaluator.evaluate_flag(&key_str, &distinct_id_str, &props) { + Ok(Some(value)) => return Ok(Some(value)), + Ok(None) => { + // Flag not found locally, fall through to API + } + Err(_) => { + // Inconclusive match, fall through to API + } + } + } + + // Fall back to API + let (feature_flags, _payloads) = self + .get_feature_flags(distinct_id_str, groups, person_properties, group_properties) + .await?; + Ok(feature_flags.get(&key_str).cloned()) + } + + /// Check if a feature flag is enabled for a user + pub async fn is_feature_enabled, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result { + let flag_value = self + .get_feature_flag( + key.into(), + distinct_id.into(), + groups, + person_properties, + group_properties, + ) + .await?; + Ok(match flag_value { + Some(FlagValue::Boolean(b)) => b, + Some(FlagValue::String(_)) => true, // Variants are considered enabled + None => false, + }) + } + + /// Get a feature flag payload for a user + pub async fn get_feature_flag_payload, D: Into>( + &self, + key: K, + distinct_id: D, + ) -> Result, Error> { + let key_str = key.into(); + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let mut payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id.into(), + }); + + // Add geoip disable parameter if configured + if self.options.disable_geoip { + payload["disable_geoip"] = json!(true); + } + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .timeout(Duration::from_secs( + self.options.feature_flags_request_timeout_seconds, + )) + .send() + .await + .map_err(|e| Error::Connection(e.to_string()))?; + + if !response.status().is_success() { + return Ok(None); + } + + let flags_response: FeatureFlagsResponse = response + .json() + .await + .map_err(|e| Error::Serialization(format!("Failed to parse response: {e}")))?; + + let (_flags, payloads) = flags_response.normalize(); + Ok(payloads.get(&key_str).cloned()) + } + + /// Evaluate a feature flag locally (requires feature flags to be loaded) + pub fn evaluate_feature_flag_locally( + &self, + flag: &FeatureFlag, + distinct_id: &str, + person_properties: &HashMap, + ) -> Result { + match_feature_flag(flag, distinct_id, person_properties) + .map_err(|e| Error::Connection(e.message)) + } } diff --git a/src/client/blocking.rs b/src/client/blocking.rs index 0f9af91..4ce90b3 100644 --- a/src/client/blocking.rs +++ b/src/client/blocking.rs @@ -1,7 +1,12 @@ +use std::collections::HashMap; use std::time::Duration; use reqwest::{blocking::Client as HttpClient, header::CONTENT_TYPE}; +use serde_json::json; +use crate::endpoints::{Endpoint, EndpointManager}; +use crate::feature_flags::{match_feature_flag, FeatureFlag, FeatureFlagsResponse, FlagValue}; +use crate::local_evaluation::{FlagCache, FlagPoller, LocalEvaluationConfig, LocalEvaluator}; use crate::{event::InnerEvent, Error, Event}; use super::ClientOptions; @@ -10,28 +15,67 @@ use super::ClientOptions; pub struct Client { options: ClientOptions, client: HttpClient, + local_evaluator: Option, + _flag_poller: Option, } /// This function constructs a new client using the options provided. pub fn client>(options: C) -> Client { - let options = options.into(); + let mut options = options.into(); + // Ensure endpoint_manager is properly initialized based on the host + options.endpoint_manager = EndpointManager::new(options.host.clone()); let client = HttpClient::builder() .timeout(Duration::from_secs(options.request_timeout_seconds)) .build() .unwrap(); // Unwrap here is as safe as `HttpClient::new` - Client { options, client } + + let (local_evaluator, flag_poller) = if options.enable_local_evaluation { + if let Some(ref personal_key) = options.personal_api_key { + let cache = FlagCache::new(); + + let config = LocalEvaluationConfig { + personal_api_key: personal_key.clone(), + project_api_key: options.api_key.clone(), + api_host: options.endpoints().api_host(), + poll_interval: Duration::from_secs(options.poll_interval_seconds), + request_timeout: Duration::from_secs(options.request_timeout_seconds), + }; + + let mut poller = FlagPoller::new(config, cache.clone()); + poller.start(); + + (Some(LocalEvaluator::new(cache)), Some(poller)) + } else { + eprintln!("[FEATURE FLAGS] Local evaluation enabled but personal_api_key not set"); + (None, None) + } + } else { + (None, None) + }; + + Client { + options, + client, + local_evaluator, + _flag_poller: flag_poller, + } } impl Client { /// Capture the provided event, sending it to PostHog. pub fn capture(&self, event: Event) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + let inner_event = InnerEvent::new(event, self.options.api_key.clone()); let payload = serde_json::to_string(&inner_event).map_err(|e| Error::Serialization(e.to_string()))?; + let url = self.options.endpoints().build_url(Endpoint::Capture); self.client - .post(&self.options.api_endpoint) + .post(&url) .header(CONTENT_TYPE, "application/json") .body(payload) .send() @@ -43,6 +87,10 @@ impl Client { /// Capture a collection of events with a single request. This function may be /// more performant than capturing a list of events individually. pub fn capture_batch(&self, events: Vec) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + let events: Vec<_> = events .into_iter() .map(|event| InnerEvent::new(event, self.options.api_key.clone())) @@ -51,8 +99,9 @@ impl Client { let payload = serde_json::to_string(&events).map_err(|e| Error::Serialization(e.to_string()))?; + let url = self.options.endpoints().build_url(Endpoint::Capture); self.client - .post(&self.options.api_endpoint) + .post(&url) .header(CONTENT_TYPE, "application/json") .body(payload) .send() @@ -60,4 +109,160 @@ impl Client { Ok(()) } + + /// Get all feature flags for a user + pub fn get_feature_flags( + &self, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result< + ( + HashMap, + HashMap, + ), + Error, + > { + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let mut payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id, + }); + + if let Some(groups) = groups { + payload["groups"] = json!(groups); + } + + if let Some(person_properties) = person_properties { + payload["person_properties"] = json!(person_properties); + } + + if let Some(group_properties) = group_properties { + payload["group_properties"] = json!(group_properties); + } + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .send() + .map_err(|e| Error::Connection(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status(); + let text = response + .text() + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::Connection(format!( + "API request failed with status {}: {}", + status, text + ))); + } + + let flags_response = response.json::().map_err(|e| { + Error::Serialization(format!("Failed to parse feature flags response: {}", e)) + })?; + + Ok(flags_response.normalize()) + } + + /// Get a specific feature flag value for a user + pub fn get_feature_flag( + &self, + key: String, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result, Error> { + // Try local evaluation first if available + if let Some(ref evaluator) = self.local_evaluator { + let props = person_properties.clone().unwrap_or_default(); + match evaluator.evaluate_flag(&key, &distinct_id, &props) { + Ok(Some(value)) => return Ok(Some(value)), + Ok(None) => { + // Flag not found locally, fall through to API + } + Err(_) => { + // Inconclusive match, fall through to API + } + } + } + + // Fall back to API + let (feature_flags, _payloads) = + self.get_feature_flags(distinct_id, groups, person_properties, group_properties)?; + Ok(feature_flags.get(&key).cloned()) + } + + /// Check if a feature flag is enabled for a user + pub fn is_feature_enabled( + &self, + key: String, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result { + let flag_value = self.get_feature_flag( + key.into(), + distinct_id.into(), + groups, + person_properties, + group_properties, + )?; + Ok(match flag_value { + Some(FlagValue::Boolean(b)) => b, + Some(FlagValue::String(_)) => true, // Variants are considered enabled + None => false, + }) + } + + /// Get a feature flag payload for a user + pub fn get_feature_flag_payload, D: Into>( + &self, + key: K, + distinct_id: D, + ) -> Result, Error> { + let key_str = key.into(); + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id.into(), + }); + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .send() + .map_err(|e| Error::Connection(e.to_string()))?; + + if !response.status().is_success() { + return Ok(None); + } + + let flags_response: FeatureFlagsResponse = response + .json() + .map_err(|e| Error::Serialization(format!("Failed to parse response: {}", e)))?; + + let (_flags, payloads) = flags_response.normalize(); + Ok(payloads.get(&key_str).cloned()) + } + + /// Evaluate a feature flag locally (requires feature flags to be loaded) + pub fn evaluate_feature_flag_locally( + &self, + flag: &FeatureFlag, + distinct_id: &str, + person_properties: &HashMap, + ) -> Result { + match_feature_flag(flag, distinct_id, person_properties) + .map_err(|e| Error::Connection(e.message)) + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index a60d772..fba49fe 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,4 +1,4 @@ -use crate::API_ENDPOINT; +use crate::endpoints::EndpointManager; use derive_builder::Builder; #[cfg(not(feature = "async-client"))] @@ -15,14 +15,68 @@ pub use async_client::client; #[cfg(feature = "async-client")] pub use async_client::Client; -#[derive(Builder)] +#[derive(Builder, Clone)] pub struct ClientOptions { - #[builder(default = "API_ENDPOINT.to_string()")] - api_endpoint: String, + /// Host URL for the PostHog API (defaults to US ingestion endpoint) + #[builder(setter(into, strip_option), default)] + host: Option, + + /// Project API key (required) api_key: String, + /// Request timeout in seconds #[builder(default = "30")] request_timeout_seconds: u64, + + /// Personal API key for fetching flag definitions (required for local evaluation) + #[builder(setter(into, strip_option), default)] + personal_api_key: Option, + + /// Enable local evaluation of feature flags + #[builder(default = "false")] + enable_local_evaluation: bool, + + /// Interval for polling flag definitions (in seconds) + #[builder(default = "30")] + poll_interval_seconds: u64, + + /// Enable gzip compression for requests + #[builder(default = "false")] + gzip: bool, + + /// Disable tracking (useful for development) + #[builder(default = "false")] + disabled: bool, + + /// Disable automatic geoip enrichment + #[builder(default = "false")] + disable_geoip: bool, + + /// Feature flags request timeout in seconds + #[builder(default = "3")] + feature_flags_request_timeout_seconds: u64, + + #[builder(setter(skip))] + #[builder(default = "EndpointManager::new(None)")] + endpoint_manager: EndpointManager, +} + +impl ClientOptions { + /// Get the endpoint manager + pub(crate) fn endpoints(&self) -> &EndpointManager { + &self.endpoint_manager + } + + /// Check if the client is disabled + pub fn is_disabled(&self) -> bool { + self.disabled + } + + /// Create ClientOptions with properly initialized endpoint_manager + fn with_endpoint_manager(mut self) -> Self { + self.endpoint_manager = EndpointManager::new(self.host.clone()); + self + } } impl From<&str> for ClientOptions { @@ -31,5 +85,18 @@ impl From<&str> for ClientOptions { .api_key(api_key.to_string()) .build() .expect("We always set the API key, so this is infallible") + .with_endpoint_manager() + } +} + +impl From<(&str, &str)> for ClientOptions { + /// Create options from API key and host + fn from((api_key, host): (&str, &str)) -> Self { + ClientOptionsBuilder::default() + .api_key(api_key.to_string()) + .host(host.to_string()) + .build() + .expect("We always set the API key, so this is infallible") + .with_endpoint_manager() } } diff --git a/src/endpoints.rs b/src/endpoints.rs new file mode 100644 index 0000000..18cc130 --- /dev/null +++ b/src/endpoints.rs @@ -0,0 +1,183 @@ +use std::fmt; + +/// US ingestion endpoint +pub const US_INGESTION_ENDPOINT: &str = "https://us.i.posthog.com"; + +/// EU ingestion endpoint +pub const EU_INGESTION_ENDPOINT: &str = "https://eu.i.posthog.com"; + +/// Default host (US by default) +pub const DEFAULT_HOST: &str = US_INGESTION_ENDPOINT; + +/// API endpoints for different operations +#[derive(Debug, Clone)] +pub enum Endpoint { + /// Event capture endpoint + Capture, + /// Feature flags endpoint + Flags, + /// Local evaluation endpoint + LocalEvaluation, +} + +impl Endpoint { + /// Get the path for this endpoint + pub fn path(&self) -> &str { + match self { + Endpoint::Capture => "/i/v0/e/", + Endpoint::Flags => "/flags/?v=2", + Endpoint::LocalEvaluation => "/api/feature_flag/local_evaluation/?send_cohorts", + } + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.path()) + } +} + +/// Manages PostHog API endpoints and host configuration +#[derive(Debug, Clone)] +pub struct EndpointManager { + base_host: String, + raw_host: String, +} + +impl EndpointManager { + /// Create a new endpoint manager with the given host + pub fn new(host: Option) -> Self { + let raw_host = host.clone().unwrap_or_else(|| DEFAULT_HOST.to_string()); + let base_host = Self::determine_server_host(host); + + Self { + base_host, + raw_host, + } + } + + /// Determine the actual server host based on the provided host + /// Similar to posthog-python's determine_server_host function + pub fn determine_server_host(host: Option) -> String { + let host_or_default = host.unwrap_or_else(|| DEFAULT_HOST.to_string()); + let trimmed_host = host_or_default.trim_end_matches('/'); + + match trimmed_host { + "https://app.posthog.com" | "https://us.posthog.com" => { + US_INGESTION_ENDPOINT.to_string() + } + "https://eu.posthog.com" => EU_INGESTION_ENDPOINT.to_string(), + _ => host_or_default, + } + } + + /// Get the base host URL (for constructing endpoints) + pub fn base_host(&self) -> &str { + &self.base_host + } + + /// Get the raw host (as provided by the user, used for session replay URLs) + pub fn raw_host(&self) -> &str { + &self.raw_host + } + + /// Build a full URL for a given endpoint + pub fn build_url(&self, endpoint: Endpoint) -> String { + format!( + "{}{}", + self.base_host.trim_end_matches('/'), + endpoint.path() + ) + } + + /// Build a URL with a custom path + pub fn build_custom_url(&self, path: &str) -> String { + let normalized_path = if path.starts_with('/') { + path.to_string() + } else { + format!("/{path}") + }; + format!( + "{}{}", + self.base_host.trim_end_matches('/'), + normalized_path + ) + } + + /// Build the local evaluation URL with a token + pub fn build_local_eval_url(&self, token: &str) -> String { + format!( + "{}/api/feature_flag/local_evaluation/?token={}&send_cohorts", + self.base_host.trim_end_matches('/'), + token + ) + } + + /// Get the base host for API operations (without the path) + pub fn api_host(&self) -> String { + self.base_host.trim_end_matches('/').to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_determine_server_host() { + assert_eq!( + EndpointManager::determine_server_host(None), + US_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://app.posthog.com".to_string())), + US_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://us.posthog.com".to_string())), + US_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://eu.posthog.com".to_string())), + EU_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://custom.domain.com".to_string())), + "https://custom.domain.com" + ); + } + + #[test] + fn test_build_url() { + let manager = EndpointManager::new(None); + + assert_eq!( + manager.build_url(Endpoint::Capture), + format!("{}/i/v0/e/", US_INGESTION_ENDPOINT) + ); + + assert_eq!( + manager.build_url(Endpoint::Flags), + format!("{}/flags/?v=2", US_INGESTION_ENDPOINT) + ); + } + + #[test] + fn test_build_custom_url() { + let manager = EndpointManager::new(Some("https://custom.com/".to_string())); + + assert_eq!( + manager.build_custom_url("/api/test"), + "https://custom.com/api/test" + ); + + assert_eq!( + manager.build_custom_url("api/test"), + "https://custom.com/api/test" + ); + } +} diff --git a/src/error.rs b/src/error.rs index e193a85..c9f1972 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,11 +3,11 @@ use std::fmt::{Display, Formatter}; impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Error::Connection(msg) => write!(f, "Connection Error: {}", msg), - Error::Serialization(msg) => write!(f, "Serialization Error: {}", msg), + Error::Connection(msg) => write!(f, "Connection Error: {msg}"), + Error::Serialization(msg) => write!(f, "Serialization Error: {msg}"), Error::AlreadyInitialized => write!(f, "Client already initialized"), Error::NotInitialized => write!(f, "Client not initialized"), - Error::InvalidTimestamp(msg) => write!(f, "Invalid Timestamp: {}", msg), + Error::InvalidTimestamp(msg) => write!(f, "Invalid Timestamp: {msg}"), } } } diff --git a/src/feature_flags.rs b/src/feature_flags.rs new file mode 100644 index 0000000..27fa916 --- /dev/null +++ b/src/feature_flags.rs @@ -0,0 +1,693 @@ +use serde::{Deserialize, Serialize}; +use sha1::{Digest, Sha1}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum FlagValue { + Boolean(bool), + String(String), +} + +#[derive(Debug)] +pub struct InconclusiveMatchError { + pub message: String, +} + +impl InconclusiveMatchError { + pub fn new(message: &str) -> Self { + Self { + message: message.to_string(), + } + } +} + +impl Default for FlagValue { + fn default() -> Self { + FlagValue::Boolean(false) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FeatureFlag { + pub key: String, + pub active: bool, + #[serde(default)] + pub filters: FeatureFlagFilters, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct FeatureFlagFilters { + #[serde(default)] + pub groups: Vec, + #[serde(default)] + pub multivariate: Option, + #[serde(default)] + pub payloads: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FeatureFlagCondition { + #[serde(default)] + pub properties: Vec, + pub rollout_percentage: Option, + pub variant: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Property { + pub key: String, + pub value: serde_json::Value, + #[serde(default = "default_operator")] + pub operator: String, + #[serde(rename = "type")] + pub property_type: Option, +} + +fn default_operator() -> String { + "exact".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct MultivariateFilter { + pub variants: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MultivariateVariant { + pub key: String, + pub rollout_percentage: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum FeatureFlagsResponse { + // v2 API format (/flags/?v=2) + V2 { + flags: HashMap, + #[serde(rename = "errorsWhileComputingFlags")] + #[serde(default)] + errors_while_computing_flags: bool, + }, + // Legacy format (old decide endpoint) + Legacy { + #[serde(rename = "featureFlags")] + feature_flags: HashMap, + #[serde(rename = "featureFlagPayloads")] + #[serde(default)] + feature_flag_payloads: HashMap, + #[serde(default)] + errors: Option>, + }, +} + +impl FeatureFlagsResponse { + /// Convert the response to a normalized format + pub fn normalize( + self, + ) -> ( + HashMap, + HashMap, + ) { + match self { + FeatureFlagsResponse::V2 { flags, .. } => { + let mut feature_flags = HashMap::new(); + let mut payloads = HashMap::new(); + + for (key, detail) in flags { + if detail.enabled { + if let Some(variant) = detail.variant { + feature_flags.insert(key.clone(), FlagValue::String(variant)); + } else { + feature_flags.insert(key.clone(), FlagValue::Boolean(true)); + } + } else { + feature_flags.insert(key.clone(), FlagValue::Boolean(false)); + } + + if let Some(metadata) = detail.metadata { + if let Some(payload) = metadata.payload { + payloads.insert(key, payload); + } + } + } + + (feature_flags, payloads) + } + FeatureFlagsResponse::Legacy { + feature_flags, + feature_flag_payloads, + .. + } => (feature_flags, feature_flag_payloads), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlagDetail { + pub key: String, + pub enabled: bool, + pub variant: Option, + #[serde(default)] + pub reason: Option, + #[serde(default)] + pub metadata: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlagReason { + pub code: String, + #[serde(default)] + pub condition_index: Option, + #[serde(default)] + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlagMetadata { + pub id: u64, + pub version: u32, + pub description: Option, + pub payload: Option, +} + +const LONG_SCALE: f64 = 0xFFFFFFFFFFFFFFFu64 as f64; // Must be exactly 15 F's to match Python SDK + +pub fn hash_key(key: &str, distinct_id: &str, salt: &str) -> f64 { + let hash_key = format!("{key}.{distinct_id}{salt}"); + let mut hasher = Sha1::new(); + hasher.update(hash_key.as_bytes()); + let result = hasher.finalize(); + let hex_str = format!("{result:x}"); + let hash_val = u64::from_str_radix(&hex_str[..15], 16).unwrap_or(0); + hash_val as f64 / LONG_SCALE +} + +pub fn get_matching_variant(flag: &FeatureFlag, distinct_id: &str) -> Option { + let hash_value = hash_key(&flag.key, distinct_id, "variant"); + let variants = flag.filters.multivariate.as_ref()?.variants.as_slice(); + + let mut value_min = 0.0; + for variant in variants { + let value_max = value_min + variant.rollout_percentage / 100.0; + if hash_value >= value_min && hash_value < value_max { + return Some(variant.key.clone()); + } + value_min = value_max; + } + None +} + +pub fn match_feature_flag( + flag: &FeatureFlag, + distinct_id: &str, + properties: &HashMap, +) -> Result { + if !flag.active { + return Ok(FlagValue::Boolean(false)); + } + + let conditions = &flag.filters.groups; + + // Sort conditions to evaluate variant overrides first + let mut sorted_conditions = conditions.clone(); + sorted_conditions.sort_by_key(|c| if c.variant.is_some() { 0 } else { 1 }); + + let mut is_inconclusive = false; + + for condition in sorted_conditions { + match is_condition_match(flag, distinct_id, &condition, properties) { + Ok(true) => { + if let Some(variant_override) = &condition.variant { + // Check if variant is valid + if let Some(ref multivariate) = flag.filters.multivariate { + let valid_variants: Vec = multivariate + .variants + .iter() + .map(|v| v.key.clone()) + .collect(); + + if valid_variants.contains(variant_override) { + return Ok(FlagValue::String(variant_override.clone())); + } + } + } + + // Try to get matching variant or return true + if let Some(variant) = get_matching_variant(flag, distinct_id) { + return Ok(FlagValue::String(variant)); + } + return Ok(FlagValue::Boolean(true)); + } + Ok(false) => continue, + Err(_) => { + is_inconclusive = true; + } + } + } + + if is_inconclusive { + return Err(InconclusiveMatchError::new( + "Can't determine if feature flag is enabled or not with given properties", + )); + } + + Ok(FlagValue::Boolean(false)) +} + +fn is_condition_match( + flag: &FeatureFlag, + distinct_id: &str, + condition: &FeatureFlagCondition, + properties: &HashMap, +) -> Result { + // Check properties first + for prop in &condition.properties { + if !match_property(prop, properties)? { + return Ok(false); + } + } + + // If all properties match (or no properties), check rollout percentage + if let Some(rollout_percentage) = condition.rollout_percentage { + let hash_value = hash_key(&flag.key, distinct_id, ""); + if hash_value > (rollout_percentage / 100.0) { + return Ok(false); + } + } + + Ok(true) +} + +fn match_property( + property: &Property, + properties: &HashMap, +) -> Result { + let value = match properties.get(&property.key) { + Some(v) => v, + None => { + // Handle is_not_set operator + if property.operator == "is_not_set" { + return Ok(true); + } + // Handle is_set operator + if property.operator == "is_set" { + return Ok(false); + } + // For other operators, missing property is inconclusive + return Err(InconclusiveMatchError::new(&format!( + "Property '{}' not found in provided properties", + property.key + ))); + } + }; + + Ok(match property.operator.as_str() { + "exact" => { + if property.value.is_array() { + if let Some(arr) = property.value.as_array() { + for val in arr { + if compare_values(val, value) { + return Ok(true); + } + } + return Ok(false); + } + } + compare_values(&property.value, value) + } + "is_not" => { + if property.value.is_array() { + if let Some(arr) = property.value.as_array() { + for val in arr { + if compare_values(val, value) { + return Ok(false); + } + } + return Ok(true); + } + } + !compare_values(&property.value, value) + } + "is_set" => true, // We already know the property exists + "is_not_set" => false, // We already know the property exists + "icontains" => { + let prop_str = value_to_string(value); + let search_str = value_to_string(&property.value); + prop_str.to_lowercase().contains(&search_str.to_lowercase()) + } + "not_icontains" => { + let prop_str = value_to_string(value); + let search_str = value_to_string(&property.value); + !prop_str.to_lowercase().contains(&search_str.to_lowercase()) + } + "regex" => { + let prop_str = value_to_string(value); + let regex_str = value_to_string(&property.value); + match regex::Regex::new(®ex_str) { + Ok(re) => re.is_match(&prop_str), + Err(_) => false, + } + } + "not_regex" => { + let prop_str = value_to_string(value); + let regex_str = value_to_string(&property.value); + match regex::Regex::new(®ex_str) { + Ok(re) => !re.is_match(&prop_str), + Err(_) => true, + } + } + "gt" | "gte" | "lt" | "lte" => compare_numeric(&property.operator, &property.value, value), + _ => false, + }) +} + +fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> bool { + // Case-insensitive string comparison + if let (Some(a_str), Some(b_str)) = (a.as_str(), b.as_str()) { + return a_str.eq_ignore_ascii_case(b_str); + } + + // Direct comparison for other types + a == b +} + +fn value_to_string(value: &serde_json::Value) -> String { + match value { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Number(n) => n.to_string(), + serde_json::Value::Bool(b) => b.to_string(), + _ => value.to_string(), + } +} + +fn compare_numeric( + operator: &str, + property_value: &serde_json::Value, + value: &serde_json::Value, +) -> bool { + let prop_num = match property_value { + serde_json::Value::Number(n) => n.as_f64(), + serde_json::Value::String(s) => s.parse::().ok(), + _ => None, + }; + + let val_num = match value { + serde_json::Value::Number(n) => n.as_f64(), + serde_json::Value::String(s) => s.parse::().ok(), + _ => None, + }; + + if let (Some(prop), Some(val)) = (prop_num, val_num) { + match operator { + "gt" => val > prop, + "gte" => val >= prop, + "lt" => val < prop, + "lte" => val <= prop, + _ => false, + } + } else { + // Fall back to string comparison + let prop_str = value_to_string(property_value); + let val_str = value_to_string(value); + match operator { + "gt" => val_str > prop_str, + "gte" => val_str >= prop_str, + "lt" => val_str < prop_str, + "lte" => val_str <= prop_str, + _ => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_hash_key() { + let hash = hash_key("test-flag", "user-123", ""); + assert!(hash >= 0.0 && hash <= 1.0); + + // Same inputs should produce same hash + let hash2 = hash_key("test-flag", "user-123", ""); + assert_eq!(hash, hash2); + + // Different inputs should produce different hash + let hash3 = hash_key("test-flag", "user-456", ""); + assert_ne!(hash, hash3); + } + + #[test] + fn test_simple_flag_match() { + let flag = FeatureFlag { + key: "test-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + assert_eq!(result, FlagValue::Boolean(true)); + } + + #[test] + fn test_property_matching() { + let prop = Property { + key: "country".to_string(), + value: json!("US"), + operator: "exact".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("country".to_string(), json!("US")); + + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("country".to_string(), json!("UK")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_multivariate_variants() { + let flag = FeatureFlag { + key: "test-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: Some(MultivariateFilter { + variants: vec![ + MultivariateVariant { + key: "control".to_string(), + rollout_percentage: 50.0, + }, + MultivariateVariant { + key: "test".to_string(), + rollout_percentage: 50.0, + }, + ], + }), + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + + match result { + FlagValue::String(variant) => { + assert!(variant == "control" || variant == "test"); + } + _ => panic!("Expected string variant"), + } + } + + #[test] + fn test_inactive_flag() { + let flag = FeatureFlag { + key: "inactive-flag".to_string(), + active: false, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + assert_eq!(result, FlagValue::Boolean(false)); + } + + #[test] + fn test_rollout_percentage() { + let flag = FeatureFlag { + key: "rollout-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(30.0), // 30% rollout + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + + // Test with multiple users to ensure distribution + let mut enabled_count = 0; + for i in 0..1000 { + let result = match_feature_flag(&flag, &format!("user-{}", i), &properties).unwrap(); + if result == FlagValue::Boolean(true) { + enabled_count += 1; + } + } + + // Should be roughly 30% enabled (allow for some variance) + assert!(enabled_count > 250 && enabled_count < 350); + } + + #[test] + fn test_regex_operator() { + let prop = Property { + key: "email".to_string(), + value: json!(".*@company\\.com$"), + operator: "regex".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("email".to_string(), json!("user@company.com")); + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("email".to_string(), json!("user@example.com")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_icontains_operator() { + let prop = Property { + key: "name".to_string(), + value: json!("ADMIN"), + operator: "icontains".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("name".to_string(), json!("admin_user")); + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("name".to_string(), json!("regular_user")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_numeric_operators() { + // Greater than + let prop_gt = Property { + key: "age".to_string(), + value: json!(18), + operator: "gt".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("age".to_string(), json!(25)); + assert!(match_property(&prop_gt, &properties).unwrap()); + + properties.insert("age".to_string(), json!(15)); + assert!(!match_property(&prop_gt, &properties).unwrap()); + + // Less than or equal + let prop_lte = Property { + key: "score".to_string(), + value: json!(100), + operator: "lte".to_string(), + property_type: None, + }; + + properties.insert("score".to_string(), json!(100)); + assert!(match_property(&prop_lte, &properties).unwrap()); + + properties.insert("score".to_string(), json!(101)); + assert!(!match_property(&prop_lte, &properties).unwrap()); + } + + #[test] + fn test_is_set_operator() { + let prop = Property { + key: "email".to_string(), + value: json!(true), + operator: "is_set".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("email".to_string(), json!("test@example.com")); + assert!(match_property(&prop, &properties).unwrap()); + + properties.remove("email"); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_is_not_set_operator() { + let prop = Property { + key: "phone".to_string(), + value: json!(true), + operator: "is_not_set".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("phone".to_string(), json!("+1234567890")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_empty_groups() { + let flag = FeatureFlag { + key: "empty-groups".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + assert_eq!(result, FlagValue::Boolean(false)); + } + + #[test] + fn test_hash_scale_constant() { + // Verify the constant is exactly 15 F's (not 16) + assert_eq!(LONG_SCALE, 0xFFFFFFFFFFFFFFFu64 as f64); + assert_ne!(LONG_SCALE, 0xFFFFFFFFFFFFFFFFu64 as f64); + } +} diff --git a/src/lib.rs b/src/lib.rs index 6fe46d5..e7615c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,10 @@ mod client; +mod endpoints; mod error; mod event; +mod feature_flags; mod global; - -const API_ENDPOINT: &str = "https://us.i.posthog.com/i/v0/e/"; +mod local_evaluation; // Public interface - any change to this is breaking! // Client @@ -13,12 +14,32 @@ pub use client::ClientOptions; pub use client::ClientOptionsBuilder; pub use client::ClientOptionsBuilderError; +// Endpoints +pub use endpoints::{ + Endpoint, EndpointManager, DEFAULT_HOST, EU_INGESTION_ENDPOINT, US_INGESTION_ENDPOINT, +}; + // Error pub use error::Error; // Event pub use event::Event; +// Feature Flags +pub use feature_flags::{ + match_feature_flag, FeatureFlag, FeatureFlagCondition, FeatureFlagFilters, + FeatureFlagsResponse, FlagDetail, FlagMetadata, FlagReason, FlagValue, InconclusiveMatchError, + MultivariateFilter, MultivariateVariant, Property, +}; + +// Local Evaluation +pub use local_evaluation::{ + Cohort, FlagCache, FlagPoller, LocalEvaluationConfig, LocalEvaluationResponse, LocalEvaluator, +}; + +#[cfg(feature = "async-client")] +pub use local_evaluation::AsyncFlagPoller; + // We expose a global capture function as a convenience, that uses a global client pub use global::capture; pub use global::disable as disable_global; diff --git a/src/local_evaluation.rs b/src/local_evaluation.rs new file mode 100644 index 0000000..eba19d0 --- /dev/null +++ b/src/local_evaluation.rs @@ -0,0 +1,409 @@ +use crate::feature_flags::{match_feature_flag, FeatureFlag, FlagValue, InconclusiveMatchError}; +use crate::Error; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LocalEvaluationResponse { + pub flags: Vec, + #[serde(default)] + pub group_type_mapping: HashMap, + #[serde(default)] + pub cohorts: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Cohort { + pub id: String, + pub name: String, + pub properties: serde_json::Value, +} + +/// Manages locally cached feature flags for evaluation +#[derive(Clone)] +pub struct FlagCache { + flags: Arc>>, + group_type_mapping: Arc>>, + cohorts: Arc>>, +} + +impl Default for FlagCache { + fn default() -> Self { + Self::new() + } +} + +impl FlagCache { + pub fn new() -> Self { + Self { + flags: Arc::new(RwLock::new(HashMap::new())), + group_type_mapping: Arc::new(RwLock::new(HashMap::new())), + cohorts: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub fn update(&self, response: LocalEvaluationResponse) { + let mut flags = self.flags.write().unwrap(); + flags.clear(); + for flag in response.flags { + flags.insert(flag.key.clone(), flag); + } + + let mut mapping = self.group_type_mapping.write().unwrap(); + *mapping = response.group_type_mapping; + + let mut cohorts = self.cohorts.write().unwrap(); + *cohorts = response.cohorts; + } + + pub fn get_flag(&self, key: &str) -> Option { + self.flags.read().unwrap().get(key).cloned() + } + + pub fn get_all_flags(&self) -> Vec { + self.flags.read().unwrap().values().cloned().collect() + } + + pub fn clear(&self) { + self.flags.write().unwrap().clear(); + self.group_type_mapping.write().unwrap().clear(); + self.cohorts.write().unwrap().clear(); + } +} + +/// Configuration for local evaluation +#[derive(Clone)] +pub struct LocalEvaluationConfig { + pub personal_api_key: String, + pub project_api_key: String, + pub api_host: String, + pub poll_interval: Duration, + pub request_timeout: Duration, +} + +/// Manages polling for feature flag definitions +pub struct FlagPoller { + config: LocalEvaluationConfig, + cache: FlagCache, + client: reqwest::blocking::Client, + stop_signal: Arc>, + thread_handle: Option>, +} + +impl FlagPoller { + pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self { + let client = reqwest::blocking::Client::builder() + .timeout(config.request_timeout) + .build() + .unwrap(); + + Self { + config, + cache, + client, + stop_signal: Arc::new(RwLock::new(false)), + thread_handle: None, + } + } + + /// Start the polling thread + pub fn start(&mut self) { + // Initial load + if let Err(e) = self.load_flags() { + eprintln!("Failed to load initial flags: {e}"); + } + + let config = self.config.clone(); + let cache = self.cache.clone(); + let stop_signal = self.stop_signal.clone(); + + let handle = std::thread::spawn(move || { + let client = reqwest::blocking::Client::builder() + .timeout(config.request_timeout) + .build() + .unwrap(); + + loop { + std::thread::sleep(config.poll_interval); + + if *stop_signal.read().unwrap() { + break; + } + + let url = format!( + "{}/api/feature_flag/local_evaluation/?token={}&send_cohorts", + config.api_host.trim_end_matches('/'), + config.project_api_key + ); + + match client + .get(&url) + .header( + "Authorization", + format!("Bearer {}", config.personal_api_key), + ) + .send() + { + Ok(response) => { + if response.status().is_success() { + match response.json::() { + Ok(data) => cache.update(data), + Err(e) => { + eprintln!("[FEATURE FLAGS] Failed to parse flag response: {e}") + } + } + } else { + eprintln!( + "[FEATURE FLAGS] Failed to fetch flags: HTTP {}", + response.status() + ); + } + } + Err(e) => eprintln!("[FEATURE FLAGS] Failed to fetch flags: {e}"), + } + } + }); + + self.thread_handle = Some(handle); + } + + /// Load flags synchronously + pub fn load_flags(&self) -> Result<(), Error> { + let url = format!( + "{}/api/feature_flag/local_evaluation/?token={}&send_cohorts", + self.config.api_host.trim_end_matches('/'), + self.config.project_api_key + ); + + let response = self + .client + .get(&url) + .header( + "Authorization", + format!("Bearer {}", self.config.personal_api_key), + ) + .send() + .map_err(|e| Error::Connection(e.to_string()))?; + + if !response.status().is_success() { + return Err(Error::Connection(format!("HTTP {}", response.status()))); + } + + let data = response + .json::() + .map_err(|e| Error::Serialization(e.to_string()))?; + + self.cache.update(data); + Ok(()) + } + + /// Stop the polling thread + pub fn stop(&mut self) { + *self.stop_signal.write().unwrap() = true; + if let Some(handle) = self.thread_handle.take() { + handle.join().ok(); + } + } +} + +impl Drop for FlagPoller { + fn drop(&mut self) { + self.stop(); + } +} + +/// Async version of the flag poller +#[cfg(feature = "async-client")] +pub struct AsyncFlagPoller { + config: LocalEvaluationConfig, + cache: FlagCache, + client: reqwest::Client, + stop_signal: Arc>, + task_handle: Option>, + is_running: Arc>, +} + +#[cfg(feature = "async-client")] +impl AsyncFlagPoller { + pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self { + let client = reqwest::Client::builder() + .timeout(config.request_timeout) + .build() + .unwrap(); + + Self { + config, + cache, + client, + stop_signal: Arc::new(tokio::sync::RwLock::new(false)), + task_handle: None, + is_running: Arc::new(tokio::sync::RwLock::new(false)), + } + } + + /// Start the polling task + pub async fn start(&mut self) { + // Check if already running + { + let mut is_running = self.is_running.write().await; + if *is_running { + return; // Already running + } + *is_running = true; + } + + // Initial load + if let Err(e) = self.load_flags().await { + eprintln!("[FEATURE FLAGS] Failed to load initial flags: {e}"); + } + + let config = self.config.clone(); + let cache = self.cache.clone(); + let stop_signal = self.stop_signal.clone(); + let is_running = self.is_running.clone(); + let client = self.client.clone(); + + let task = tokio::spawn(async move { + let mut interval = tokio::time::interval(config.poll_interval); + interval.tick().await; // Skip the first immediate tick + + loop { + tokio::select! { + _ = interval.tick() => { + if *stop_signal.read().await { + break; + } + + let url = format!( + "{}/api/feature_flag/local_evaluation/?token={}&send_cohorts", + config.api_host.trim_end_matches('/'), + config.project_api_key + ); + + match client + .get(&url) + .header("Authorization", format!("Bearer {}", config.personal_api_key)) + .send() + .await + { + Ok(response) => { + if response.status().is_success() { + match response.json::().await { + Ok(data) => cache.update(data), + Err(e) => eprintln!("[FEATURE FLAGS] Failed to parse flag response: {e}"), + } + } else { + eprintln!("[FEATURE FLAGS] Failed to fetch flags: HTTP {}", response.status()); + } + } + Err(e) => eprintln!("[FEATURE FLAGS] Failed to fetch flags: {e}"), + } + } + } + } + + // Clear running flag when task exits + *is_running.write().await = false; + }); + + self.task_handle = Some(task); + } + + /// Load flags asynchronously + pub async fn load_flags(&self) -> Result<(), Error> { + let url = format!( + "{}/api/feature_flag/local_evaluation/?token={}&send_cohorts", + self.config.api_host.trim_end_matches('/'), + self.config.project_api_key + ); + + let response = self + .client + .get(&url) + .header( + "Authorization", + format!("Bearer {}", self.config.personal_api_key), + ) + .send() + .await + .map_err(|e| Error::Connection(e.to_string()))?; + + if !response.status().is_success() { + return Err(Error::Connection(format!("HTTP {}", response.status()))); + } + + let data = response + .json::() + .await + .map_err(|e| Error::Serialization(e.to_string()))?; + + self.cache.update(data); + Ok(()) + } + + /// Stop the polling task + pub async fn stop(&mut self) { + *self.stop_signal.write().await = true; + if let Some(handle) = self.task_handle.take() { + handle.abort(); + } + *self.is_running.write().await = false; + } + + /// Check if the poller is running + pub async fn is_running(&self) -> bool { + *self.is_running.read().await + } +} + +#[cfg(feature = "async-client")] +impl Drop for AsyncFlagPoller { + fn drop(&mut self) { + // Abort the task if still running + if let Some(handle) = self.task_handle.take() { + handle.abort(); + } + } +} + +/// Evaluator for locally cached flags +pub struct LocalEvaluator { + cache: FlagCache, +} + +impl LocalEvaluator { + pub fn new(cache: FlagCache) -> Self { + Self { cache } + } + + /// Evaluate a feature flag locally + pub fn evaluate_flag( + &self, + key: &str, + distinct_id: &str, + person_properties: &HashMap, + ) -> Result, InconclusiveMatchError> { + match self.cache.get_flag(key) { + Some(flag) => match_feature_flag(&flag, distinct_id, person_properties).map(Some), + None => Ok(None), + } + } + + /// Get all flags and evaluate them + pub fn evaluate_all_flags( + &self, + distinct_id: &str, + person_properties: &HashMap, + ) -> HashMap> { + let mut results = HashMap::new(); + + for flag in self.cache.get_all_flags() { + let result = match_feature_flag(&flag, distinct_id, person_properties); + results.insert(flag.key.clone(), result); + } + + results + } +} diff --git a/tests/test.rs b/tests/test.rs index 4d27cca..1a37d12 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,31 @@ -#[cfg(feature = "e2e-test")] +#[cfg(all(feature = "e2e-test", feature = "async-client"))] +#[tokio::test] +async fn get_client_async() { + use dotenv::dotenv; + dotenv().ok(); // Load the .env file + println!("Loaded .env for tests"); + + // see https://us.posthog.com/project/115809/ for the e2e project + use posthog_rs::Event; + use std::collections::HashMap; + + let api_key = std::env::var("POSTHOG_RS_E2E_TEST_API_KEY").unwrap(); + let client = posthog_rs::client(api_key.as_str()).await; + + let mut child_map = HashMap::new(); + child_map.insert("child_key1", "child_value1"); + + let mut event = Event::new("e2e test event", "1234"); + event.insert_prop("key1", "value1").unwrap(); + event.insert_prop("key2", vec!["a", "b"]).unwrap(); + event.insert_prop("key3", child_map).unwrap(); + + client.capture(event).await.unwrap(); +} + +#[cfg(all(feature = "e2e-test", not(feature = "async-client")))] #[test] -fn get_client() { +fn get_client_blocking() { use dotenv::dotenv; dotenv().ok(); // Load the .env file println!("Loaded .env for tests"); diff --git a/tests/test_async.rs b/tests/test_async.rs new file mode 100644 index 0000000..6856ec1 --- /dev/null +++ b/tests/test_async.rs @@ -0,0 +1,422 @@ +#![cfg(feature = "async-client")] + +use httpmock::prelude::*; +use posthog_rs::FlagValue; +use serde_json::json; +use std::collections::HashMap; + +async fn create_test_client(base_url: String) -> posthog_rs::Client { + // Use the From implementation to ensure endpoint_manager is set up correctly + let options: posthog_rs::ClientOptions = (("test_api_key", base_url.as_str())).into(); + posthog_rs::client(options).await +} + +#[tokio::test] +async fn test_get_all_feature_flags() { + let server = MockServer::start(); + + let mock_response = json!({ + "featureFlags": { + "test-flag": true, + "disabled-flag": false, + "variant-flag": "control" + }, + "featureFlagPayloads": { + "variant-flag": { + "color": "blue", + "size": "large" + } + } + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user" + })); + then.status(200) + .header("content-type", "application/json") + .json_body(mock_response); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flags("test-user".to_string(), None, None, None) + .await; + + if let Err(e) = &result { + eprintln!("Error: {:?}", e); + eprintln!("Mock server URL: {}", server.base_url()); + } + assert!(result.is_ok()); + let (feature_flags, payloads) = result.unwrap(); + + assert_eq!( + feature_flags.get("test-flag"), + Some(&FlagValue::Boolean(true)) + ); + assert_eq!( + feature_flags.get("disabled-flag"), + Some(&FlagValue::Boolean(false)) + ); + assert_eq!( + feature_flags.get("variant-flag"), + Some(&FlagValue::String("control".to_string())) + ); + + assert!(payloads.contains_key("variant-flag")); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_is_feature_enabled() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "featureFlags": { + "enabled-flag": true, + "disabled-flag": false + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let enabled_result = client + .is_feature_enabled( + "enabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + let disabled_result = client + .is_feature_enabled( + "disabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(disabled_result.is_ok()); + assert_eq!(disabled_result.unwrap(), false); + + flags_mock.assert_hits(2); +} + +#[tokio::test] +async fn test_get_feature_flag_with_properties() { + let server = MockServer::start(); + + let person_properties = json!({ + "country": "US", + "age": 25, + "plan": "premium" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user", + "person_properties": person_properties + })); + then.status(200).json_body(json!({ + "featureFlags": { + "premium-feature": true + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let mut props = HashMap::new(); + props.insert("country".to_string(), json!("US")); + props.insert("age".to_string(), json!(25)); + props.insert("plan".to_string(), json!("premium")); + + let result = client + .get_feature_flag( + "premium-feature".to_string(), + "test-user".to_string(), + None, + Some(props), + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_multivariate_flag() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "featureFlags": { + "experiment": "variant-b" + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(FlagValue::String("variant-b".to_string())) + ); + + let enabled_result = client + .is_feature_enabled( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + flags_mock.assert_hits(2); +} + +#[tokio::test] +async fn test_api_error_handling() { + let server = MockServer::start(); + + let error_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(500).body("Internal Server Error"); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flags("test-user".to_string(), None, None, None) + .await; + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("500")); + + error_mock.assert(); +} + +#[tokio::test] +async fn test_get_feature_flag_payload() { + let server = MockServer::start(); + + let payload_data = json!({ + "steps": ["welcome", "profile", "preferences"], + "theme": "dark" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "featureFlags": { + "onboarding-flow": "variant-a" + }, + "featureFlagPayloads": { + "onboarding-flow": payload_data + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag_payload("onboarding-flow".to_string(), "test-user".to_string()) + .await; + + assert!(result.is_ok()); + let payload = result.unwrap(); + assert!(payload.is_some()); + + let payload_value = payload.unwrap(); + assert_eq!(payload_value["theme"], "dark"); + assert!(payload_value["steps"].is_array()); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_nonexistent_flag() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "featureFlags": {}, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag( + "nonexistent-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), None); + + let enabled_result = client + .is_feature_enabled( + "nonexistent-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), false); + + flags_mock.assert_hits(2); +} + +#[tokio::test] +async fn test_empty_distinct_id() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "" + })); + then.status(200).json_body(json!({ + "featureFlags": { + "test-flag": true + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag("test-flag".to_string(), "".to_string(), None, None, None) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_groups_parameter() { + let server = MockServer::start(); + + let groups_json = json!({ + "company": "acme-corp", + "team": "engineering" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user", + "groups": groups_json + })); + then.status(200).json_body(json!({ + "featureFlags": { + "team-feature": true + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let mut groups = HashMap::new(); + groups.insert("company".to_string(), "acme-corp".to_string()); + groups.insert("team".to_string(), "engineering".to_string()); + + let result = client + .get_feature_flag( + "team-feature".to_string(), + "test-user".to_string(), + Some(groups), + None, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_malformed_response() { + let server = MockServer::start(); + + let malformed_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).body("not json"); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flags("test-user".to_string(), None, None, None) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("expected")); + + malformed_mock.assert(); +} diff --git a/tests/test_blocking.rs b/tests/test_blocking.rs new file mode 100644 index 0000000..162bd49 --- /dev/null +++ b/tests/test_blocking.rs @@ -0,0 +1,222 @@ +#![cfg(not(feature = "async-client"))] + +use httpmock::prelude::*; +use posthog_rs::FlagValue; +use serde_json::json; +use std::collections::HashMap; + +fn create_test_client(base_url: String) -> posthog_rs::Client { + // Use the From implementation to ensure endpoint_manager is set up correctly + let options: posthog_rs::ClientOptions = (("test_api_key", base_url.as_str())).into(); + posthog_rs::client(options) +} + +#[test] +fn test_get_all_feature_flags() { + let server = MockServer::start(); + + let mock_response = json!({ + "featureFlags": { + "test-flag": true, + "disabled-flag": false, + "variant-flag": "control" + }, + "featureFlagPayloads": { + "variant-flag": { + "color": "blue", + "size": "large" + } + } + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user" + })); + then.status(200) + .header("content-type", "application/json") + .json_body(mock_response); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flags("test-user".to_string(), None, None, None); + + assert!(result.is_ok()); + let (feature_flags, payloads) = result.unwrap(); + + assert_eq!( + feature_flags.get("test-flag"), + Some(&FlagValue::Boolean(true)) + ); + assert_eq!( + feature_flags.get("disabled-flag"), + Some(&FlagValue::Boolean(false)) + ); + assert_eq!( + feature_flags.get("variant-flag"), + Some(&FlagValue::String("control".to_string())) + ); + + assert!(payloads.contains_key("variant-flag")); + + flags_mock.assert(); +} + +#[test] +fn test_is_feature_enabled() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "featureFlags": { + "enabled-flag": true, + "disabled-flag": false + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()); + + let enabled_result = client.is_feature_enabled( + "enabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + let disabled_result = client.is_feature_enabled( + "disabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(disabled_result.is_ok()); + assert_eq!(disabled_result.unwrap(), false); + + flags_mock.assert_hits(2); +} + +#[test] +fn test_get_feature_flag_with_properties() { + let server = MockServer::start(); + + let person_properties = json!({ + "country": "US", + "age": 25, + "plan": "premium" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user", + "person_properties": person_properties + })); + then.status(200).json_body(json!({ + "featureFlags": { + "premium-feature": true + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()); + + let mut props = HashMap::new(); + props.insert("country".to_string(), json!("US")); + props.insert("age".to_string(), json!(25)); + props.insert("plan".to_string(), json!("premium")); + + let result = client.get_feature_flag( + "premium-feature".to_string(), + "test-user".to_string(), + None, + Some(props), + None, + ); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[test] +fn test_multivariate_flag() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "featureFlags": { + "experiment": "variant-b" + }, + "featureFlagPayloads": {} + })); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flag( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(FlagValue::String("variant-b".to_string())) + ); + + let enabled_result = client.is_feature_enabled( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + flags_mock.assert_hits(2); +} + +#[test] +fn test_api_error_handling() { + let server = MockServer::start(); + + let error_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(500).body("Internal Server Error"); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flags("test-user".to_string(), None, None, None); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("500")); + + error_mock.assert(); +} diff --git a/tests/test_local_evaluation.rs b/tests/test_local_evaluation.rs new file mode 100644 index 0000000..34548cc --- /dev/null +++ b/tests/test_local_evaluation.rs @@ -0,0 +1,243 @@ +use httpmock::prelude::*; +use posthog_rs::{ + ClientOptionsBuilder, FeatureFlag, FeatureFlagCondition, FeatureFlagFilters, FlagCache, + FlagValue, LocalEvaluationResponse, LocalEvaluator, Property, +}; +use serde_json::json; +use std::collections::HashMap; +use std::time::Duration; + +#[test] +fn test_local_evaluation_basic() { + // Create a cache and evaluator + let cache = FlagCache::new(); + let evaluator = LocalEvaluator::new(cache.clone()); + + // Create a simple flag + let flag = FeatureFlag { + key: "test-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + // Update cache with the flag + let response = LocalEvaluationResponse { + flags: vec![flag], + group_type_mapping: HashMap::new(), + cohorts: HashMap::new(), + }; + cache.update(response); + + // Test evaluation + let properties = HashMap::new(); + let result = evaluator.evaluate_flag("test-flag", "user-123", &properties); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); +} + +#[test] +fn test_local_evaluation_with_properties() { + let cache = FlagCache::new(); + let evaluator = LocalEvaluator::new(cache.clone()); + + // Create a flag with property conditions + let flag = FeatureFlag { + key: "premium-feature".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![Property { + key: "plan".to_string(), + value: json!("premium"), + operator: "exact".to_string(), + property_type: None, + }], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + // Update cache + let response = LocalEvaluationResponse { + flags: vec![flag], + group_type_mapping: HashMap::new(), + cohorts: HashMap::new(), + }; + cache.update(response); + + // Test with matching properties + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("premium")); + + let result = evaluator.evaluate_flag("premium-feature", "user-123", &properties); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + // Test with non-matching properties + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("free")); + + let result = evaluator.evaluate_flag("premium-feature", "user-456", &properties); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(false))); +} + +#[test] +fn test_local_evaluation_missing_flag() { + let cache = FlagCache::new(); + let evaluator = LocalEvaluator::new(cache); + + let properties = HashMap::new(); + let result = evaluator.evaluate_flag("non-existent", "user-123", &properties); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), None); +} + +#[cfg(feature = "async-client")] +#[tokio::test] +async fn test_local_evaluation_with_mock_server() { + let server = MockServer::start(); + + // Mock the local evaluation endpoint + let mock_flags = json!({ + "flags": [ + { + "key": "feature-a", + "active": true, + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 50.0, + "variant": null + } + ], + "multivariate": null, + "payloads": {} + } + }, + { + "key": "feature-b", + "active": true, + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "@company.com", + "operator": "icontains" + } + ], + "rollout_percentage": 100.0, + "variant": null + } + ], + "multivariate": null, + "payloads": {} + } + } + ], + "group_type_mapping": {}, + "cohorts": {} + }); + + let eval_mock = server.mock(|when, then| { + when.method(GET) + .path("/api/feature_flag/local_evaluation/") + .header("Authorization", "Bearer test_personal_key") + .query_param("token", "test_project_key") + .query_param("send_cohorts", ""); + then.status(200).json_body(mock_flags); + }); + + // Create client with local evaluation enabled + let options = ClientOptionsBuilder::default() + .host(server.base_url()) + .api_key("test_project_key".to_string()) + .personal_api_key("test_personal_key".to_string()) + .enable_local_evaluation(true) + .poll_interval_seconds(60) + .build() + .unwrap(); + + let client = posthog_rs::client(options).await; + + // Give it a moment to load initial flags + tokio::time::sleep(Duration::from_millis(100)).await; + + // Test local evaluation + let mut properties = HashMap::new(); + properties.insert("email".to_string(), json!("test@company.com")); + + let result = client + .get_feature_flag("feature-b", "user-123", None, Some(properties), None) + .await; + + assert!(result.is_ok()); + // The actual result depends on whether the mock was hit and processed + + eval_mock.assert(); +} + +#[test] +fn test_cache_operations() { + let cache = FlagCache::new(); + + // Create multiple flags + let flags = vec![ + FeatureFlag { + key: "flag1".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![], + multivariate: None, + payloads: HashMap::new(), + }, + }, + FeatureFlag { + key: "flag2".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![], + multivariate: None, + payloads: HashMap::new(), + }, + }, + ]; + + let response = LocalEvaluationResponse { + flags: flags.clone(), + group_type_mapping: HashMap::new(), + cohorts: HashMap::new(), + }; + + cache.update(response); + + // Test get_flag + assert!(cache.get_flag("flag1").is_some()); + assert!(cache.get_flag("flag2").is_some()); + assert!(cache.get_flag("flag3").is_none()); + + // Test get_all_flags + let all_flags = cache.get_all_flags(); + assert_eq!(all_flags.len(), 2); + + // Test clear + cache.clear(); + assert!(cache.get_flag("flag1").is_none()); + assert_eq!(cache.get_all_flags().len(), 0); +}