diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a1d853..30bfdd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,44 @@ +## 0.6.0 - 2025-11-05 + +### Added + +- Feature flags support (boolean, multivariate, payloads) +- Local evaluation for 100-1000x faster flag evaluation with background polling +- Automatic `$feature_flag_called` event tracking with deduplication +- Property-based targeting and group (B2B) support +- New methods: `is_feature_enabled()`, `get_feature_flag()`, `get_feature_flags()`, `get_feature_flag_payload()` + +#### New Dependencies: + +- Added `sha1` for flag matching algorithms +- Added `regex` for property matching in feature flags +- Added `tokio` (optional) for async local evaluation with background polling +- Added `json` and `gzip` features to `reqwest` for flag payloads and compression +- Dev dependencies: `httpmock` for testing, `futures` for async tests + +## 0.5.0 - 2025-11-05 + +### Minor Changes + +Configuration system now accepts base URLs instead of full endpoint URLs +- Provide just the hostname (e.g., `https://eu.posthog.com`) +- SDK automatically appends `/i/v0/e/` for single events and `/batch/` for batch events +- Old format with full URLs still works - paths are automatically stripped and normalized +- Enables simultaneous use of both single-event and batch endpoints +## 0.4.0 - 2025-11-05 + +### Minor Changes + + - Refactored error handling to use organized error types (`TransportError`, `ValidationError`, `InitializationError`) with structured data (timeouts, status codes, batch sizes) that can be pattern matched + - Existing errors will continue to work with deprecation warnings. + + - New helper methods: + - `is_retryable()` - identifies transient errors (timeouts, 5xx, 429) + - `is_client_error()` - identifies 4xx errors + +#### New Dependencies: + - Added `thiserror` to reduce writing manual error handling boilerplate + ## 0.2.6 - 2025-01-08 diff --git a/Cargo.toml b/Cargo.toml index d7056c6..73f88f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,19 +12,29 @@ rust-version = "1.78.0" reqwest = { version = "0.11.3", default-features = false, features = [ "rustls-tls", "blocking", + "json", + "gzip", ] } serde = { version = "1.0.125", features = ["derive"] } chrono = { version = "0.4.19", features = ["serde"] } serde_json = "1.0.64" semver = "1.0.24" -derive_builder = "0.20.2" uuid = { version = "1.13.2", features = ["serde", "v7"] } +sha1 = "0.10" +regex = "1.10" +tokio = { version = "1", features = ["rt", "sync", "time", "macros"], optional = true } +url = "2.5" +thiserror = "2.0" [dev-dependencies] dotenv = "0.15.0" ctor = "0.1.26" +tokio = { version = "1", features = ["full"] } +httpmock = "0.7" +serde_json = "1.0" +futures = "0.3" [features] default = ["async-client"] e2e-test = [] -async-client = [] +async-client = ["tokio"] diff --git a/README.md b/README.md index 3fd7966..673c67c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PostHog Rust -Please see the main [PostHog docs](https://posthog.com/docs). +Please see the main [PostHog docs](https://posthog.com/docs) **This crate is under development** @@ -13,12 +13,243 @@ Add `posthog-rs` to your `Cargo.toml`. posthog-rs = "0.3.7" ``` +## Basic Usage (US Region) + ```rust +use posthog_rs::{Event, ClientOptionsBuilder}; + +// Simple initialization with API key (defaults to US region) let client = posthog_rs::client(env!("POSTHOG_API_KEY")); -let mut event = posthog_rs::Event::new("test", "1234"); -event.insert_prop("key1", "value1").unwrap(); -event.insert_prop("key2", vec!["a", "b"]).unwrap(); +// Create and send an event +let mut event = Event::new("user_signed_up", "user_distinct_id"); +event.insert_prop("plan", "premium").unwrap(); +event.insert_prop("source", "web").unwrap(); + +client.capture(event).unwrap(); +``` + +## EU Region Configuration + +```rust +use posthog_rs::{Event, ClientOptionsBuilder}; + +// Configure for EU region - just provide the base URL +let options = ClientOptionsBuilder::new() + .api_key("phc_your_api_key") + .api_endpoint("https://eu.posthog.com") // SDK handles /i/v0/e/ and /batch/ automatically + .build() + .unwrap(); + +let client = posthog_rs::client(options); +// Single event capture +let event = Event::new("user_signed_up", "user_distinct_id"); client.capture(event).unwrap(); + +// Batch event capture (uses same base URL, different endpoint path) +let events = vec![ + Event::new("page_view", "user_1"), + Event::new("button_click", "user_2"), +]; +client.capture_batch(events).unwrap(); +``` + +## Backward Compatibility + +Old format with full URLs still works - the SDK automatically normalizes them: + +```rust +// This still works - path is automatically stripped +let options = ClientOptionsBuilder::new() + .api_key("phc_your_api_key") + .api_endpoint("https://eu.posthog.com/i/v0/e/") // Gets normalized to base URL + .build() + .unwrap(); +``` + +## Feature Flags + +The SDK now supports PostHog feature flags, allowing you to control feature rollout and run A/B tests. + +### Basic Usage + +```rust +use posthog_rs::{ClientOptionsBuilder, FlagValue}; +use std::collections::HashMap; +use serde_json::json; + +let options = ClientOptionsBuilder::default() + .api_key("phc_your_project_key") + .build() + .unwrap(); + +let client = posthog_rs::client(options); + +// Check if a feature is enabled +let is_enabled = client.is_feature_enabled( + "feature-key".to_string(), + "user-id".to_string(), + None, None, None +).unwrap(); + +// Get feature flag value (boolean or variant) +match client.get_feature_flag( + "feature-key".to_string(), + "user-id".to_string(), + None, None, None +).unwrap() { + Some(FlagValue::Boolean(enabled)) => println!("Flag is: {}", enabled), + Some(FlagValue::String(variant)) => println!("Variant: {}", variant), + None => println!("Flag is disabled"), +} +``` + +### With Properties + +```rust +// Include person properties for flag evaluation +let mut person_props = HashMap::new(); +person_props.insert("plan".to_string(), json!("enterprise")); +person_props.insert("country".to_string(), json!("US")); + +let flag = client.get_feature_flag( + "premium-feature".to_string(), + "user-id".to_string(), + None, + Some(person_props), + None +).unwrap(); +``` + +### With Groups (B2B) + +```rust +// For B2B apps with group-based flags +let mut groups = HashMap::new(); +groups.insert("company".to_string(), "company-123".to_string()); + +let mut group_props = HashMap::new(); +let mut company_props = HashMap::new(); +company_props.insert("size".to_string(), json!(500)); +group_props.insert("company".to_string(), company_props); + +let flag = client.get_feature_flag( + "b2b-feature".to_string(), + "user-id".to_string(), + Some(groups), + None, + Some(group_props) +).unwrap(); ``` + +### Get All Flags + +```rust +// Get all feature flags for a user +let response = client.get_feature_flags( + "user-id".to_string(), + None, None, None +).unwrap(); + +for (key, value) in response.feature_flags { + println!("Flag {}: {:?}", key, value); +} +``` + +### Feature Flag Payloads + +```rust +// Get additional data associated with a feature flag +let payload = client.get_feature_flag_payload( + "onboarding-flow".to_string(), + "user-id".to_string() +).unwrap(); + +if let Some(data) = payload { + println!("Payload: {}", data); +} +``` + +### Local Evaluation (High Performance) + +For significantly faster flag evaluation, enable local evaluation to cache flag definitions locally: + +```rust +use posthog_rs::ClientOptionsBuilder; + +let options = ClientOptionsBuilder::default() + .api_key("phc_your_project_key") + .personal_api_key("phx_your_personal_key") // Required for local evaluation + .enable_local_evaluation(true) + .poll_interval_seconds(30) // Update cache every 30s + .build() + .unwrap(); + +let client = posthog_rs::client(options); + +// Flag evaluations now happen locally (no API calls needed) +let enabled = client.is_feature_enabled( + "new-feature".to_string(), + "user-123".to_string(), + None, None, None +).unwrap(); +``` + +**Performance:** Local evaluation is 100-1000x faster than API evaluation (~119µs vs ~125ms per request). + +Get your personal API key at: https://app.posthog.com/me/settings + +### Automatic Event Tracking + +The SDK automatically captures `$feature_flag_called` events when you evaluate feature flags. These events include: +- Feature flag key and response value +- Deduplication per user + flag + value combination +- Rich metadata (payloads, versions, request IDs) + +To disable automatic events globally: +```rust +let options = ClientOptionsBuilder::default() + .api_key("phc_your_key") + .send_feature_flag_events(false) + .build() + .unwrap(); +``` + +## Error Handling + +The SDK provides error handling with semantic categories: + +```rust +use posthog_rs::{Error, TransportError, ValidationError}; + +match client.capture(event).await { + Ok(_) => println!("Event sent successfully"), + Err(Error::Transport(TransportError::Timeout(duration))) => { + eprintln!("Request timed out after {:?}", duration); + // Retry logic here + } + Err(Error::Transport(TransportError::HttpError(401, _))) => { + eprintln!("Invalid API key - check your configuration"); + } + Err(e) if e.is_retryable() => { + // Automatically handles: timeouts, 5xx errors, 429 rate limits + tokio::time::sleep(Duration::from_secs(2)).await; + client.capture(event).await?; + } + Err(e) => eprintln!("Permanent error: {}", e), +} +``` + +### Error Categories + +- **TransportError**: Network issues (DNS, timeouts, HTTP errors, connection failures) +- **ValidationError**: Data problems (serialization, batch size, invalid timestamps) +- **InitializationError**: Configuration issues (already initialized, not initialized) + +### Helper Methods + +- `is_retryable()`: Returns `true` for transient errors (timeouts, 5xx, 429) +- `is_client_error()`: Returns `true` for 4xx HTTP errors + +See [`examples/error_classification.rs`](examples/error_classification.rs) for comprehensive error handling patterns diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..88f280f --- /dev/null +++ b/examples/README.md @@ -0,0 +1,143 @@ +# PostHog Rust SDK Examples + +This directory contains example applications demonstrating how to use the PostHog Rust SDK + +## Available Examples + +### 1. Feature Flags (`feature_flags.rs`) +Comprehensive feature flag operations and patterns. + +```bash +# With a PostHog API key +export POSTHOG_API_TOKEN=phc_your_key +cargo run --example feature_flags --features async-client + +# Without API key (demo mode - shows code structure) +cargo run --example feature_flags --features async-client +``` + +**Demonstrates:** +- **Example 1**: Simple boolean flag checks using `is_feature_enabled` +- **Example 2**: Multivariate flags (A/B testing) with `get_feature_flag` returning variants (control, variant-a, variant-b) +- **Example 3**: Property-based targeting with person properties (plan, country, account_age_days) +- **Example 4**: Groups (B2B) - Organization and team-level features with group properties: + - Company properties: name, plan, employees, industry + - Team properties: name, size +- **Example 5**: Batch flag evaluation with `get_feature_flags` - getting all flags and payloads at once +- **Example 6**: Feature flag payloads with `get_feature_flag_payload` - JSON configuration data + +### 2. Local Evaluation (`local_evaluation.rs`) +Performance optimization through local flag caching. + +```bash +export POSTHOG_API_TOKEN=phc_your_project_key +export POSTHOG_PERSONAL_API_TOKEN=phx_your_personal_key +cargo run --example local_evaluation --features async-client +``` + +**Requirements:** +- Project API token (`phc_...`) +- Personal API token (`phx_...`) - [Create one](https://app.posthog.com/me/settings) + +**Demonstrates:** +- Creating two clients for comparison: one with local evaluation, one without +- Performance benchmarking: 10 API requests vs 10 local evaluation requests +- Real-time speedup calculation showing 100x-1000x improvement +- Automatic background polling for flag definition updates (configurable with `poll_interval_seconds`) +- Batch flag evaluation performance with `get_feature_flags` +- Using `ClientOptionsBuilder` with `enable_local_evaluation(true)` and `personal_api_key` + +### 3. Feature Flag Events (`feature_flag_events.rs`) +Automatic `$feature_flag_called` event tracking. + +```bash +export POSTHOG_API_TOKEN=phc_your_project_key +cargo run --example feature_flag_events --features async-client +``` + +**Demonstrates:** +- **Example 1**: Automatic `$feature_flag_called` event capture when evaluating flags (enabled by default) +- **Example 2**: Event deduplication - same user + flag + value combination only sends event once +- **Example 3**: Events captured for different users (user-1, user-2, user-3) - each gets separate events +- **Example 4**: Multivariate flag events with variant information + +### 4. Advanced Configuration (`advanced_config.rs`) +SDK configuration patterns for different use cases. + +```bash +export POSTHOG_API_TOKEN=phc_your_key +export POSTHOG_PERSONAL_API_TOKEN=phx_your_personal_key +cargo run --example advanced_config --features async-client +``` + +**Shows 5 Configuration Patterns:** +1. **Basic client**: `posthog_rs::client("phc_test_api_key")` - US region by default +2. **EU data residency**: `posthog_rs::client(("phc_key", EU_INGESTION_ENDPOINT))` - GDPR compliant +3. **Self-hosted PostHog**: `posthog_rs::client(("phc_key", "https://analytics.mycompany.com"))` +4. **Production-optimized**: ClientOptionsBuilder with: + - `gzip(true)` - compress requests + - `request_timeout_seconds(30)` - 30s timeout +5. **High-performance**: Local evaluation with: + - `enable_local_evaluation(true)` - cache flags locally + - `poll_interval_seconds(30)` - update cache every 30s + - `feature_flags_request_timeout_seconds(3)` - faster timeouts + +## Quick Start + +The simplest way to get started: + +```bash +# Try feature flags without an API key (demo mode - shows code structure) +cargo run --example feature_flags --features async-client + +# Or with your PostHog account +export POSTHOG_API_TOKEN=phc_your_key +cargo run --example feature_flags --features async-client +``` + +## Key Concepts + +### Feature Flag Types + +1. **Boolean Flags**: Simple on/off toggles + ```rust + FlagValue::Boolean(true) // enabled + FlagValue::Boolean(false) // disabled + ``` + +2. **Multivariate Flags**: Multiple variants for A/B/n testing + ```rust + FlagValue::String("control") + FlagValue::String("variant-a") + FlagValue::String("variant-b") + ``` + +### Evaluation Methods + +1. **Remote Evaluation**: Calls PostHog API for the latest flag values (default) +2. **Local Evaluation**: Caches flag definitions locally for 100-1000x faster evaluation (requires personal API key) + +### Properties + +- **Person Properties**: User attributes (country, plan, account_age_days, email, etc.) +- **Group Properties**: Organization/team attributes for B2B apps (company plan, employees, team size, etc.) + +### Automatic Event Tracking + +When you evaluate feature flags, the SDK automatically sends `$feature_flag_called` events to PostHog (enabled by default). These events include: +- Which flags were checked (`$feature_flag`) +- What value was returned (`$feature_flag_response`) +- User identifier (`distinct_id`) +- User properties used for evaluation + +Events are deduplicated per user + flag + value combination to avoid duplicate tracking. + +## Common Use Cases + +1. **Feature Rollouts**: Gradually release features to users +2. **A/B Testing**: Test different variants (control, variant-a, variant-b) to measure impact +3. **User Targeting**: Enable features for specific user segments using person properties +4. **B2B Group Targeting**: Target entire organizations or teams using group properties +5. **Kill Switches**: Quickly disable problematic features +6. **Beta Programs**: Give early access to specific users +7. **Performance**: Use local evaluation for high-throughput applications (100-1000x faster) diff --git a/examples/advanced_config.rs b/examples/advanced_config.rs new file mode 100644 index 0000000..f5a2f61 --- /dev/null +++ b/examples/advanced_config.rs @@ -0,0 +1,88 @@ +/// SDK Configuration Examples +/// +/// Shows different ways to configure the PostHog Rust SDK for various use cases. +/// +/// Run with: cargo run --example advanced_config --features async-client +#[cfg(feature = "async-client")] +use posthog_rs::{ClientOptionsBuilder, EU_INGESTION_ENDPOINT}; + +#[cfg(feature = "async-client")] +#[tokio::main] +async fn main() -> Result<(), Box> { + let api_key = match std::env::var("POSTHOG_API_TOKEN") { + Ok(key) => key, + Err(_) => { + eprintln!("Error: POSTHOG_API_TOKEN environment variable not set"); + eprintln!("Please set it to your PostHog project API token"); + eprintln!("\nExample: export POSTHOG_API_TOKEN=phc_..."); + std::process::exit(1); + } + }; + + let personal_key = match std::env::var("POSTHOG_PERSONAL_API_TOKEN") { + Ok(key) => key, + Err(_) => { + eprintln!("Error: POSTHOG_PERSONAL_API_TOKEN environment variable not set"); + eprintln!("Please set it to your PostHog personal API token"); + eprintln!("\nTo create a personal API key:"); + eprintln!("1. Go to https://app.posthog.com/me/settings"); + eprintln!("2. Click 'Create personal API key'"); + eprintln!("3. Export it: export POSTHOG_PERSONAL_API_TOKEN=phx_..."); + std::process::exit(1); + } + }; + + println!("=== PostHog SDK Configuration Examples ===\n"); + + // 1. SIMPLEST: Just an API key (uses US endpoint by default) + println!("1. Basic client (US region):"); + let _basic = posthog_rs::client("phc_test_api_key").await; + println!(" → Created with default settings\n"); + + // 2. REGIONAL: EU data residency + println!("2. EU region client:"); + let _eu = posthog_rs::client(("phc_test_api_key", EU_INGESTION_ENDPOINT)).await; + println!(" → Data stays in EU (GDPR compliant)\n"); + + // 3. SELF-HOSTED: Your own PostHog instance + println!("3. Self-hosted instance:"); + let _custom = posthog_rs::client(("phc_test_api_key", "https://analytics.mycompany.com")).await; + println!(" → Uses your private PostHog deployment\n"); + + // 4. PRODUCTION: Common production settings + println!("4. Production configuration:"); + let production_config = ClientOptionsBuilder::default() + .api_key(api_key.clone()) + .host("https://eu.posthog.com") // Auto-detects and uses EU ingestion + .gzip(true) // Compress requests + .request_timeout_seconds(30) // 30s timeout + .build()?; + + let _prod = posthog_rs::client(production_config).await; + println!(" → Optimized for production workloads\n"); + + // 5. HIGH-PERFORMANCE: Local flag evaluation + println!("5. High-performance with local evaluation:"); + let performance_config = ClientOptionsBuilder::default() + .api_key(api_key) + .personal_api_key(personal_key) // Required for local eval + .enable_local_evaluation(true) // Cache flags locally + .poll_interval_seconds(30) // Update cache every 30s + .feature_flags_request_timeout_seconds(3) + .build()?; + + let _perf = posthog_rs::client(performance_config).await; + println!(" → Evaluates flags locally (100x faster)\n"); + + println!("✅ Configuration examples complete!"); + println!("\nTip: Check out 'feature_flags' example for flag usage"); + println!(" and 'local_evaluation' for performance optimization."); + + Ok(()) +} + +#[cfg(not(feature = "async-client"))] +fn main() { + println!("This example requires the async-client feature."); + println!("Run with: cargo run --example advanced_config --features async-client"); +} diff --git a/examples/config_migration.rs b/examples/config_migration.rs new file mode 100644 index 0000000..da7c926 --- /dev/null +++ b/examples/config_migration.rs @@ -0,0 +1,131 @@ +/// Run this example (blocking/sync client only): +/// cargo run --example config_migration --no-default-features + +#[cfg(feature = "async-client")] +fn main() { + eprintln!("ERROR: This example only works with the blocking/sync client."); + eprintln!("Run with: cargo run --example config_migration --no-default-features"); +} + +#[cfg(not(feature = "async-client"))] +fn main() { + use posthog_rs::{ClientOptionsBuilder, Event}; + //-------------------------------- + // BEFORE MIGRATION + //-------------------------------- + + println!("\nExample 0: Before Migration"); + + // Before for single event capture + let options = ClientOptionsBuilder::new() + .api_key("phc_demo") + .api_endpoint("https://eu.posthog.com/i/v0/e/") + .build() + .unwrap(); + + println!("Single event: {}", "https://eu.posthog.com/i/v0/e/"); + + let client = posthog_rs::client(options); + + let event = Event::new("user_signed_up", "distinct_id_of_the_user"); + client.capture(event).unwrap(); + + // Before for batch event capture + let options = ClientOptionsBuilder::new() + .api_key("phc_demo") + .api_endpoint("https://eu.posthog.com/batch/") + .build() + .unwrap(); + + println!("Batch event: {}", "https://eu.posthog.com/i/v0/e/"); + + let client = posthog_rs::client(options); + + let events = vec![ + Event::new("user_signed_up", "distinct_id_of_the_user"), + Event::new("user_signed_up", "distinct_id_2_of_the_user"), + ]; + client.capture_batch(events).unwrap(); + + //-------------------------------- + // AFTER MIGRATION + //-------------------------------- + + // These changes are internal to the SDK and you don't need to do anything. + // We're just showing you what happens internally. + // In this example, we're using the hostname "https://eu.posthog.com" + // which will be normalized to "https://eu.posthog.com/i/v0/e/" for single event capture + // and "https://eu.posthog.com/batch/" for batch event capture. + + // Example 1: Default hostname (without endpoint url) + println!("\nExample 1: Default hostname - (without endpoint url)"); + let options = ClientOptionsBuilder::new() + .api_key("phc_demo") + .build() + .unwrap(); + + // both are internally smarted assigned + println!("Single event: {}", "https://eu.posthog.com/i/v0/e/"); + println!("Batch event: {}", "https://eu.posthog.com/batch/"); + + let _client = posthog_rs::client(options); + + // or + + // let client = posthog_rs::client(env!("POSTHOG_API_KEY")); + + let event = Event::new("user_signed_up", "distinct_id_of_the_user"); + client.capture(event).unwrap(); + + let events = vec![ + Event::new("user_signed_up", "distinct_id_of_the_user"), + Event::new("user_signed_up", "distinct_id_2_of_the_user"), + ]; + client.capture_batch(events).unwrap(); + + // Example 2: EU region with hostname + println!("\nExample 1: just hostname"); + let options = ClientOptionsBuilder::new() + .api_key("phc_demo") + .api_endpoint("https://eu.posthog.com") + .build() + .unwrap(); + + // both are internally smarted assigned + println!("Single event: {}", "https://eu.posthog.com/i/v0/e/"); + println!("Batch event: {}", "https://eu.posthog.com/batch/"); + + let client = posthog_rs::client(options); + + let event = Event::new("user_signed_up", "distinct_id_of_the_user"); + client.capture(event).unwrap(); + + let events = vec![ + Event::new("user_signed_up", "distinct_id_of_the_user"), + Event::new("user_signed_up", "distinct_id_2_of_the_user"), + ]; + client.capture_batch(events).unwrap(); + + // Example 3: Backward compatibility + println!("\nExample 3: Backward compatible (old full URL format still works)"); + let options = ClientOptionsBuilder::new() + .api_key("phc_demo") + .api_endpoint("https://eu.posthog.com/i/v0/e/") + .build() + .unwrap(); + + println!("Input: https://eu.posthog.com/i/v0/e/ or https://eu.posthog.com/batch/"); + println!("Single event: {}", "https://eu.posthog.com/i/v0/e/"); + println!("Batch event: {}", "https://eu.posthog.com/batch/"); + + let client = posthog_rs::client(options); + + let event = Event::new("user_signed_up", "distinct_id_of_the_user"); + client.capture(event).unwrap(); + + let events = vec![ + Event::new("user_signed_up", "distinct_id_of_the_user"), + Event::new("user_signed_up", "distinct_id_2_of_the_user"), + ]; + client.capture_batch(events).unwrap(); +} diff --git a/examples/error_classification.rs b/examples/error_classification.rs new file mode 100644 index 0000000..5b15039 --- /dev/null +++ b/examples/error_classification.rs @@ -0,0 +1,182 @@ +/// Error Classification Example +/// +/// Run this example (blocking/sync client only): +/// cargo run --example error_classification --no-default-features +#[cfg(feature = "async-client")] +fn main() { + eprintln!("ERROR: This example only works with the blocking/sync client."); + eprintln!("Run with: cargo run --example error_classification --no-default-features"); +} + +#[cfg(not(feature = "async-client"))] +fn main() { + use posthog_rs::{ClientOptionsBuilder, Event}; + + println!("Error Classification Examples"); + println!("───────────────────────────────────────"); + println!("Demonstrates pattern matching on specific error types with client calls"); + println!(); + + use posthog_rs::{Error, TransportError, ValidationError}; + + // Test different configurations to trigger various errors + let test_cases = vec![ + ( + "DNS Resolution Error", + "https://invalid-domain-xyz123.posthog.com", + "phc_test", + 30, + ), + ( + "HTTP 401 Error (Invalid API Key)", + "https://eu.posthog.com", + "phc_invalid_key_demo", + 30, + ), + ( + "Timeout Error (Very short timeout)", + "https://eu.posthog.com", + "phc_test", + 0, + ), // 0 second timeout will trigger timeout + ( + "Network Error", + "https://fake-endpoint-test.io", + "phc_test", + 5, + ), + ]; + + for (test_name, endpoint, api_key, timeout_secs) in test_cases { + println!("Test name: {}", test_name); + println!("───────────────────────────────────────"); + + let test_options = ClientOptionsBuilder::default() + .api_key(api_key.to_string()) + .api_endpoint(endpoint.to_string()) + .request_timeout_seconds(timeout_secs) + .build() + .unwrap(); + + let test_client = posthog_rs::client(test_options); + let event = Event::new("test_event", "user_123"); + + match test_client.capture(event) { + Ok(_) => println!("✓ Event sent successfully"), + Err(e) => { + // println!(); + + // Old way (still works for backward compatibility) + println!("OLD WAY (deprecated but backward compatible):"); + #[allow(deprecated)] + match &e { + Error::Connection(msg) => { + println!(" Connection error string: '{}'", msg); + println!( + " Would need to parse: if msg.contains(\"timeout\") {{ retry(); }}" + ); + println!(" Fragile - breaks if error format changes!"); + } + Error::Serialization(msg) => { + println!(" Serialization error string: '{}'", msg); + println!(" Would need regex to extract details from string"); + } + Error::InvalidTimestamp(msg) => { + println!(" Invalid timestamp string: '{}'", msg); + println!(" Must parse timestamp format from string"); + } + Error::AlreadyInitialized => { + println!(" Global client already initialized"); + println!(" No additional info available"); + } + Error::NotInitialized => { + println!(" Global client not initialized"); + println!(" No additional info available"); + } + _ => { + println!(" Unknown error type - can't extract details"); + println!(" Error: {}", e); + } + } + println!(); + + // New way (structured data) + println!("NEW WAY (type-safe pattern matching):"); + match &e { + Error::Transport(TransportError::Timeout(duration)) => { + println!(" ✓ Pattern matched: TransportError::Timeout"); + println!(" ✓ Extracted duration: {:?}", duration); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Retry with exponential backoff"); + } + Error::Transport(TransportError::HttpError(401, msg)) => { + println!(" ✓ Pattern matched: TransportError::HttpError(401)"); + println!(" ✓ Extracted message: {}", msg); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Check POSTHOG_API_KEY environment variable"); + } + Error::Transport(TransportError::HttpError(status, msg)) + if *status >= 400 && *status < 500 => + { + println!(" ✓ Pattern matched: TransportError::HttpError({})", status); + println!(" ✓ Extracted status: {} (Client Error)", status); + println!(" ✓ Extracted message: {}", msg); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Fix request (client error - won't retry)"); + } + Error::Transport(TransportError::HttpError(status, msg)) if *status >= 500 => { + println!(" ✓ Pattern matched: TransportError::HttpError({})", status); + println!(" ✓ Extracted status: {}", status); + println!(" ✓ Extracted message: {}", msg); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Retry (server error)"); + } + Error::Transport(TransportError::DnsResolution(host)) => { + println!(" ✓ Pattern matched: TransportError::DnsResolution"); + println!(" ✓ Extracted hostname: {}", host); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Check network/DNS configuration"); + } + Error::Transport(TransportError::NetworkUnreachable) => { + println!(" ✓ Pattern matched: TransportError::NetworkUnreachable"); + println!(" ✓ No parsing needed - clear error type"); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Check internet connection"); + } + Error::Validation(ValidationError::BatchSizeExceeded { size, max }) => { + println!(" ✓ Pattern matched: ValidationError::BatchSizeExceeded"); + println!(" ✓ Extracted size: {}, max: {}", size, max); + println!(" ✓ Calculated chunks: {}", (size + max - 1) / max); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Split batch into chunks"); + println!(); + println!(" Auto-split code:"); + println!(" for chunk in events.chunks({}) {{", max); + println!(" client.capture_batch(chunk)?;"); + println!(" }}"); + } + Error::Validation(ValidationError::PropertyTooLarge { key, size }) => { + println!(" ✓ Pattern matched: ValidationError::PropertyTooLarge"); + println!(" ✓ Extracted property: '{}', size: {} bytes", key, size); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + println!(" → Action: Truncate or remove property"); + } + _ => { + println!(" ✓ Other error handled"); + println!(" ✓ is_retryable(): {}", e.is_retryable()); + println!(" ✓ is_client_error(): {}", e.is_client_error()); + } + } + } + } + println!(); + } +} diff --git a/examples/feature_flags.rs b/examples/feature_flags.rs new file mode 100644 index 0000000..b3a7346 --- /dev/null +++ b/examples/feature_flags.rs @@ -0,0 +1,234 @@ +/// Feature Flags Example +/// +/// Shows all feature flag patterns: boolean flags, A/B tests, payloads, targeting, and B2B groups. +/// +/// Run with real API: +/// export POSTHOG_API_TOKEN=phc_your_key +/// cargo run --example feature_flags --features async-client + +#[cfg(feature = "async-client")] +use posthog_rs::FlagValue; +#[cfg(feature = "async-client")] +use serde_json::json; +#[cfg(feature = "async-client")] +use std::collections::HashMap; + +#[cfg(feature = "async-client")] +#[tokio::main] +async fn main() { + // Try to get API key from environment, or use demo mode + let api_key = std::env::var("POSTHOG_API_TOKEN").unwrap_or_else(|_| { + println!("No POSTHOG_API_TOKEN found. Running in demo mode with mock data.\n"); + "demo_api_key".to_string() + }); + + let is_demo = api_key == "demo_api_key"; + + // Create client + let client = if is_demo { + create_demo_client().await + } else { + posthog_rs::client(api_key.as_str()).await + }; + + // Example 1: Simple boolean flag check + println!("=== Example 1: Boolean Feature Flag ==="); + let user_id = "user-123"; + + match client + .is_feature_enabled( + "new-dashboard".to_string(), + user_id.to_string(), + None, + None, + None, + ) + .await + { + Ok(enabled) => { + if enabled { + println!("✅ New dashboard is enabled for user"); + } else { + println!("❌ New dashboard is disabled for user"); + } + } + Err(e) => println!("Error checking flag: {}", e), + } + + // Example 2: Multivariate flag (A/B testing) + println!("\n=== Example 2: A/B Test Variant ==="); + + match client + .get_feature_flag( + "checkout-flow".to_string(), + user_id.to_string(), + None, + None, + None, + ) + .await + { + Ok(Some(FlagValue::String(variant))) => { + println!("User gets checkout variant: {}", variant); + match variant.as_str() { + "control" => println!(" → Show original checkout flow"), + "variant-a" => println!(" → Show streamlined checkout"), + "variant-b" => println!(" → Show one-click checkout"), + _ => println!(" → Unknown variant"), + } + } + Ok(Some(FlagValue::Boolean(enabled))) => { + println!("Checkout flow flag is a boolean: {}", enabled); + } + Ok(None) => { + println!("Checkout flow flag not found or not evaluated"); + } + Err(e) => println!("Error getting flag: {}", e), + } + + // Example 3: Using person properties for targeting + println!("\n=== Example 3: Property-based Targeting ==="); + + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("premium")); + properties.insert("country".to_string(), json!("US")); + properties.insert("account_age_days".to_string(), json!(45)); + + match client + .get_feature_flag( + "premium-features".to_string(), + user_id.to_string(), + None, + Some(properties.clone()), + None, + ) + .await + { + Ok(Some(FlagValue::Boolean(true))) => { + println!("✅ Premium features enabled (user matches targeting rules)"); + } + Ok(Some(FlagValue::Boolean(false))) => { + println!("❌ Premium features disabled (user doesn't match targeting rules)"); + } + Ok(Some(FlagValue::String(v))) => { + println!("Premium features variant: {}", v); + } + Ok(None) => { + println!("Premium features flag not found"); + } + Err(e) => println!("Error: {}", e), + } + + // Example 4: Groups (B2B) - Organization-Level Features + println!("\n=== Example 4: Groups (B2B) - Organization-Level Features ==="); + + // Set up groups: mapping of group type to group key + let mut groups = HashMap::new(); + groups.insert("company".to_string(), "company_id_123".to_string()); + groups.insert("team".to_string(), "team_design".to_string()); + + // Set up group properties: nested HashMap with group type -> properties + let mut group_properties = HashMap::new(); + + // Company properties + let mut company_props = HashMap::new(); + company_props.insert("name".to_string(), json!("Acme Corp")); + company_props.insert("plan".to_string(), json!("enterprise")); + company_props.insert("employees".to_string(), json!(250)); + company_props.insert("industry".to_string(), json!("technology")); + group_properties.insert("company".to_string(), company_props); + + // Team properties + let mut team_props = HashMap::new(); + team_props.insert("name".to_string(), json!("Design Team")); + team_props.insert("size".to_string(), json!(12)); + group_properties.insert("team".to_string(), team_props); + + match client + .get_feature_flag( + "enterprise-analytics".to_string(), + user_id.to_string(), + Some(groups.clone()), + None, // person_properties + Some(group_properties.clone()), + ) + .await + { + Ok(Some(FlagValue::Boolean(true))) => { + println!("✅ Enterprise analytics enabled for company"); + println!(" → Company: Acme Corp (250 employees)"); + println!(" → Team: Design Team (12 members)"); + } + Ok(Some(FlagValue::Boolean(false))) => { + println!("❌ Enterprise analytics disabled for this company"); + } + Ok(Some(FlagValue::String(variant))) => { + println!("Enterprise analytics variant: {}", variant); + } + Ok(None) => { + println!("Enterprise analytics flag not found"); + } + Err(e) => println!("Error: {}", e), + } + + // Example 5: Getting all flags at once + println!("\n=== Example 5: Batch Flag Evaluation ==="); + + match client + .get_all_flags_and_payloads(user_id.to_string(), None, Some(properties), None) + .await + { + Ok((flags, payloads)) => { + println!("All flags for user"); + for (flag_key, flag_value) in flags { + match flag_value { + FlagValue::Boolean(b) => println!(" {}: {}", flag_key, b), + FlagValue::String(s) => println!(" {}: \"{}\"", flag_key, s), + } + } + + if !payloads.is_empty() { + println!("\nFlag payloads:"); + for (flag_key, payload) in payloads { + println!(" {}: {}", flag_key, payload); + } + } + } + Err(e) => println!("Error getting all flags: {}", e), + } + + // Example 6: Feature flag with payload + println!("\n=== Example 6: Feature Flag Payload ==="); + + match client + .get_all_flags_and_payloads(user_id.to_string(), None, None, None) + .await + { + Ok((_flags, payloads)) => { + if let Some(payload) = payloads.get("onboarding-config") { + println!("Onboarding configuration payload:"); + println!("{}", serde_json::to_string_pretty(&payload).unwrap()); + + // Use payload data + if let Some(steps) = payload.get("steps").and_then(|v| v.as_array()) { + println!("\nOnboarding steps: {} steps total", steps.len()); + } + } else { + println!("No payload for onboarding-config flag"); + } + } + Err(e) => println!("Error getting payloads: {}", e), + } +} + +#[cfg(feature = "async-client")] +async fn create_demo_client() -> posthog_rs::Client { + println!("Note: Running in demo mode. API calls will fail but code structure is shown.\n"); + posthog_rs::client(("demo_key", "https://demo.posthog.com")).await +} + +#[cfg(not(feature = "async-client"))] +fn main() { + println!("This example requires the async-client feature."); + println!("Run with: cargo run --example feature_flags --features async-client"); +} diff --git a/examples/local_evaluation.rs b/examples/local_evaluation.rs new file mode 100644 index 0000000..3d05a79 --- /dev/null +++ b/examples/local_evaluation.rs @@ -0,0 +1,162 @@ +/// Local Evaluation Performance Demo +/// +/// Shows 100-1000x faster flag evaluation by caching definitions locally. +/// +/// Setup: +/// export POSTHOG_API_TOKEN=phc_your_project_key +/// export POSTHOG_PERSONAL_API_TOKEN=phx_your_personal_key +/// cargo run --example local_evaluation --features async-client +/// +/// Get personal key at: https://app.posthog.com/me/settings +#[cfg(feature = "async-client")] +use posthog_rs::ClientOptionsBuilder; +#[cfg(feature = "async-client")] +use serde_json::json; +#[cfg(feature = "async-client")] +use std::collections::HashMap; +#[cfg(feature = "async-client")] +use std::time::{Duration, Instant}; + +#[cfg(feature = "async-client")] +#[tokio::main] +async fn main() { + // Get API keys from environment + let api_key = match std::env::var("POSTHOG_API_TOKEN") { + Ok(key) => key, + Err(_) => { + eprintln!("Error: POSTHOG_API_TOKEN environment variable not set"); + eprintln!("Please set it to your PostHog project API token"); + eprintln!("\nExample: export POSTHOG_API_TOKEN=phc_..."); + std::process::exit(1); + } + }; + + let personal_key = match std::env::var("POSTHOG_PERSONAL_API_TOKEN") { + Ok(key) => key, + Err(_) => { + eprintln!("Error: POSTHOG_PERSONAL_API_TOKEN environment variable not set"); + eprintln!("Please set it to your PostHog personal API token"); + eprintln!("\nTo create a personal API key:"); + eprintln!("1. Go to https://app.posthog.com/me/settings"); + eprintln!("2. Click 'Create personal API key'"); + eprintln!("3. Export it: export POSTHOG_PERSONAL_API_TOKEN=phx_..."); + std::process::exit(1); + } + }; + + println!("=== Local Evaluation Performance Demo ===\n"); + + // Create client WITH local evaluation + let local_client = { + let options = ClientOptionsBuilder::default() + .api_key(api_key.clone()) + .personal_api_key(personal_key) + .enable_local_evaluation(true) + .poll_interval_seconds(30) // Poll for updates every 30 seconds + // TODO: implement proper batching like (mpsc::channel) to enable with performance + .send_feature_flag_events(false) + .build() + .unwrap(); + + posthog_rs::client(options).await + }; + + // Create client WITHOUT local evaluation (for comparison) + let api_client = { + let options = ClientOptionsBuilder::default() + .api_key(api_key) + .build() + .unwrap(); + + posthog_rs::client(options).await + }; + + // Give local evaluation time to fetch initial flags + println!("Fetching flag definitions for local evaluation..."); + tokio::time::sleep(Duration::from_secs(2)).await; + + // Test data + let user_id = "perf-test-user"; + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("enterprise")); + properties.insert("country".to_string(), json!("US")); + + // Performance comparison + println!("\n=== Performance Comparison ==="); + + // Test API evaluation speed + println!("\n1. API Evaluation (10 requests):"); + let start = Instant::now(); + for i in 0..10 { + let _ = api_client + .get_feature_flag( + "using-feature-flags".to_string(), + format!("{}-{}", user_id, i), + None, + Some(properties.clone()), + None, + ) + .await; + } + let api_duration = start.elapsed(); + println!( + " Time: {:?} total, {:?} per request", + api_duration, + api_duration / 10 + ); + + // Test local evaluation speed + println!("\n2. Local Evaluation (10 requests):"); + let start = Instant::now(); + for i in 0..10 { + let _ = local_client + .get_feature_flag( + "using-feature-flags".to_string(), + format!("{}-{}", user_id, i), + None, + Some(properties.clone()), + None, + ) + .await; + } + let local_duration = start.elapsed(); + println!( + " Time: {:?} total, {:?} per request", + local_duration, + local_duration / 10 + ); + + // Show speedup + let speedup = api_duration.as_micros() as f64 / local_duration.as_micros().max(1) as f64; + println!("\n📊 Local evaluation is {:.1}x faster!", speedup); + + // Demonstrate batch evaluation + println!("\n=== Batch Evaluation Demo ==="); + + let start = Instant::now(); + match local_client + .get_all_flags(user_id.to_string(), None, Some(properties), None) + .await + { + Ok(flags) => { + let duration = start.elapsed(); + println!("Evaluated {} flags in {:?}", flags.len(), duration); + + // Show some flag values + println!("\nSample flags:"); + for (key, value) in flags.iter().take(5) { + println!(" {}: {:?}", key, value); + } + } + Err(e) => println!("Error: {}", e), + } + + println!("\n✅ Local evaluation continues polling for updates in the background"); + println!(" Updates will be fetched every 30 seconds automatically"); +} + +#[cfg(not(feature = "async-client"))] +fn main() { + println!("This example requires the async-client feature."); + println!("Run with: cargo run --example local_evaluation --features async-client"); +} diff --git a/src/client/async_client.rs b/src/client/async_client.rs index 56139fb..1f26795 100644 --- a/src/client/async_client.rs +++ b/src/client/async_client.rs @@ -1,15 +1,33 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, RwLock}; use std::time::Duration; use reqwest::{header::CONTENT_TYPE, Client as HttpClient}; +use serde_json::json; +use crate::endpoints::Endpoint; +use crate::error::{TransportError, ValidationError}; +use crate::feature_flags::{ + match_feature_flag, FeatureFlag, FeatureFlagsResponse, FlagDetail, FlagValue, +}; +use crate::local_evaluation::{AsyncFlagPoller, FlagCache, LocalEvaluationConfig, LocalEvaluator}; use crate::{event::InnerEvent, Error, Event}; use super::ClientOptions; +/// Maximum number of distinct IDs to track for feature flag deduplication +const MAX_DICT_SIZE: usize = 50_000; + /// A [`Client`] facilitates interactions with the PostHog API over HTTP. pub struct Client { options: ClientOptions, client: HttpClient, + local_evaluator: Option, + #[allow(dead_code)] + flag_poller: Option, + /// Tracks which feature flags have been called for deduplication. + /// Maps distinct_id -> set of feature flag keys that have been reported. + distinct_ids_feature_flags_reported: Arc>>>, } /// This function constructs a new client using the options provided. @@ -19,24 +37,82 @@ pub async fn client>(options: C) -> Client { .timeout(Duration::from_secs(options.request_timeout_seconds)) .build() .unwrap(); // Unwrap here is as safe as `HttpClient::new` - Client { options, client } + + let (local_evaluator, flag_poller) = if options.enable_local_evaluation { + // Safe to unwrap: validation in ClientOptions::build() ensures personal_api_key + // is always Some when enable_local_evaluation is true + let personal_key = options + .personal_api_key + .as_ref() + .expect("personal_api_key must be present when enable_local_evaluation is true"); + + let cache = FlagCache::new(); + + let config = LocalEvaluationConfig { + personal_api_key: personal_key.clone(), + project_api_key: options.api_key.clone(), + api_host: options.endpoints().api_host(), + poll_interval: Duration::from_secs(options.poll_interval_seconds), + request_timeout: Duration::from_secs(options.request_timeout_seconds), + }; + + let poller = AsyncFlagPoller::new(config, cache.clone()); + poller.start().await; + + (Some(LocalEvaluator::new(cache)), Some(poller)) + } else { + (None, None) + }; + + Client { + options, + client, + local_evaluator, + flag_poller, + distinct_ids_feature_flags_reported: Arc::new(RwLock::new(HashMap::new())), + } } impl Client { /// Capture the provided event, sending it to PostHog. pub async fn capture(&self, event: Event) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + + // Note: Infinite loop prevention for $feature_flag_called events + // If we ever implement automatic feature flag evaluation for all events + // (auto-adding of $feature/* properties), we must skip flag evaluation + // when event.event_name() == "$feature_flag_called" to prevent infinite + // loops. Current implementation doesn't auto-evaluate flags in capture(), + // so no loop risk exists today. + let inner_event = InnerEvent::new(event, self.options.api_key.clone()); - let payload = - serde_json::to_string(&inner_event).map_err(|e| Error::Serialization(e.to_string()))?; + let payload = serde_json::to_string(&inner_event) + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; - self.client - .post(&self.options.api_endpoint) - .header(CONTENT_TYPE, "application/json") - .body(payload) - .send() - .await - .map_err(|e| Error::Connection(e.to_string()))?; + let mut url = self.options.endpoints().build_url(Endpoint::Capture); + if self.options.disable_geoip { + let separator = if url.contains('?') { "&" } else { "?" }; + url.push_str(&format!("{separator}disable_geoip=1")); + } + + let request = self + .client + .post(&url) + .header(CONTENT_TYPE, "application/json"); + + // Apply gzip compression if enabled + let request = if self.options.gzip { + // Note: reqwest will automatically compress the body when gzip feature is enabled + // and Content-Encoding header is set + request.header("Content-Encoding", "gzip").body(payload) + } else { + request.body(payload) + }; + + request.send().await.map_err(TransportError::from)?; Ok(()) } @@ -44,22 +120,434 @@ impl Client { /// Capture a collection of events with a single request. This function may be /// more performant than capturing a list of events individually. pub async fn capture_batch(&self, events: Vec) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + let events: Vec<_> = events .into_iter() - .map(|event| InnerEvent::new(event, self.options.api_key.clone())) + .map(|event| InnerEvent::new(event, self.options.api_key.to_string())) .collect(); - let payload = - serde_json::to_string(&events).map_err(|e| Error::Serialization(e.to_string()))?; + let payload = serde_json::to_string(&events) + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; + + let mut url = self.options.endpoints().batch_event_endpoint(); + if self.options.disable_geoip { + let separator = if url.contains('?') { "&" } else { "?" }; + url.push_str(&format!("{separator}disable_geoip=1")); + } + + let request = self + .client + .post(&url) + .header(CONTENT_TYPE, "application/json"); + + // Apply gzip compression if enabled + let request = if self.options.gzip { + // Note: reqwest will automatically compress the body when gzip feature is enabled + // and Content-Encoding header is set + request.header("Content-Encoding", "gzip").body(payload) + } else { + request.body(payload) + }; + + request.send().await.map_err(TransportError::from)?; + + Ok(()) + } + + /// Internal method to capture $feature_flag_called events. + /// + /// Handles deduplication and event construction inline. + #[allow(clippy::too_many_arguments)] + async fn capture_feature_flag_called( + &self, + distinct_id: &str, + flag_key: &str, + flag_response: Option<&FlagValue>, + payload: Option, + locally_evaluated: bool, + groups: Option>, + disable_geoip: Option, + request_id: Option, + flag_details: Option<&FlagDetail>, + ) -> Result<(), Error> { + // Create the reported key for deduplication + // Format: "{key}_::null::" for None, "{key}_{value}" otherwise + let feature_flag_reported_key = match flag_response { + None => format!("{}_::null::", flag_key), + Some(flag_value) => format!("{}_{}", flag_key, flag_value), + }; + + // Check if already reported for deduplication + { + let reported = self.distinct_ids_feature_flags_reported.read().unwrap(); + if let Some(flags) = reported.get(distinct_id) { + if flags.contains(&feature_flag_reported_key) { + // Already reported, skip + return Ok(()); + } + } + } + + // Build event properties inline + let mut event = Event::new("$feature_flag_called", distinct_id); + + // Add required properties + event.insert_prop("$feature_flag", flag_key)?; + event.insert_prop("$feature_flag_response", flag_response)?; + event.insert_prop("locally_evaluated", locally_evaluated)?; + + // Add $feature/{key} property + event.insert_prop(format!("$feature/{}", flag_key), flag_response)?; + + // Add optional properties + if let Some(p) = payload { + event.insert_prop("$feature_flag_payload", p)?; + } + + // Add request_id if provided + if let Some(req_id) = request_id { + event.insert_prop("$feature_flag_request_id", req_id)?; + } + + // Add flag_details metadata if provided + if let Some(details) = flag_details { + // Add reason + if let Some(reason) = &details.reason { + if let Some(desc) = &reason.description { + event.insert_prop("$feature_flag_reason", desc.clone())?; + } + } + + // Add metadata (version and id) + if let Some(metadata) = &details.metadata { + event.insert_prop("$feature_flag_version", metadata.version)?; + event.insert_prop("$feature_flag_id", metadata.id)?; + } + } + + // Add groups if present + if let Some(g) = groups { + for (group_name, group_id) in g { + event.add_group(&group_name, &group_id); + } + } + + // Set disable_geoip on event if provided + if let Some(disable_geo) = disable_geoip { + if disable_geo { + event.insert_prop("$geoip_disable", true)?; + } + } + + // Capture the event + self.capture(event).await?; + + // Mark as reported (even if capture failed to avoid retry storms) + { + let mut reported = self.distinct_ids_feature_flags_reported.write().unwrap(); + + // Check size limit and evict if necessary + if reported.len() >= MAX_DICT_SIZE && !reported.contains_key(distinct_id) { + // Remove first entry (FIFO eviction) + if let Some(first_key) = reported.keys().next().cloned() { + reported.remove(&first_key); + } + } + + reported + .entry(distinct_id.to_string()) + .or_default() + .insert(feature_flag_reported_key); + } - self.client - .post(&self.options.api_endpoint) + Ok(()) + } + + /// Get all feature flags for a user + pub async fn get_feature_flags>( + &self, + distinct_id: S, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result { + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let mut payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id.into(), + }); + + if let Some(groups) = groups { + payload["groups"] = json!(groups); + } + + if let Some(person_properties) = person_properties { + payload["person_properties"] = json!(person_properties); + } + + if let Some(group_properties) = group_properties { + payload["group_properties"] = json!(group_properties); + } + + // Add geoip disable parameter if configured + if self.options.disable_geoip { + payload["disable_geoip"] = json!(true); + } + + let response = self + .client + .post(&flags_endpoint) .header(CONTENT_TYPE, "application/json") - .body(payload) + .json(&payload) + .timeout(Duration::from_secs( + self.options.feature_flags_request_timeout_seconds, + )) .send() .await - .map_err(|e| Error::Connection(e.to_string()))?; + .map_err(TransportError::from)?; - Ok(()) + if !response.status().is_success() { + let status = response.status(); + let text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(TransportError::HttpError( + status.as_u16(), + format!("API request failed with status {status}: {text}"), + ) + .into()); + } + + let flags_response = response.json::().await.map_err(|e| { + ValidationError::SerializationFailed(format!( + "Failed to parse feature flags response: {e}" + )) + })?; + + Ok(flags_response) + } + + /// Get a specific feature flag value with control over event capture + pub(crate) async fn get_feature_flag_with_options, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + send_feature_flag_events: bool, + ) -> Result, Error> { + let key = key.into(); + let distinct_id = distinct_id.into(); + let mut locally_evaluated = false; + let mut flag_value: Option = None; + let mut payload: Option = None; + let mut request_id: Option = None; + let mut flag_details_map: HashMap = HashMap::new(); + + // Try local evaluation first if available + if let Some(ref evaluator) = self.local_evaluator { + let props = person_properties.clone().unwrap_or_default(); + match evaluator.evaluate_flag(&key, &distinct_id, &props) { + Ok(Some(value)) => { + locally_evaluated = true; + flag_value = Some(value); + } + Ok(None) => { + // Flag not found locally, fall through to API + } + Err(_e) => { + // Inconclusive match, fall through to API + } + } + } + + // Fall back to API if not locally evaluated + if flag_value.is_none() { + match self + .get_feature_flags( + distinct_id.clone(), + groups.clone(), + person_properties, + group_properties, + ) + .await + { + Ok(response) => { + // Use helper methods to get flag value and payload + flag_value = response.get_flag_value(&key); + payload = response.get_flag_payload(&key); + request_id = response.request_id; + flag_details_map = response.flags; + } + Err(_e) => { + // Return None on error (graceful degradation) + flag_value = None; + } + } + } + + // Capture $feature_flag_called event if enabled + if self.options.send_feature_flag_events && send_feature_flag_events { + let _ = self + .capture_feature_flag_called( + &distinct_id, + &key, + flag_value.as_ref(), + payload, + locally_evaluated, + groups, + None, // disable_geoip - not yet supported + request_id, + flag_details_map.get(&key), + ) + .await; + } + + Ok(flag_value) + } + + /// Get a specific feature flag value for a user. + /// + /// Automatically captures a `$feature_flag_called` event to PostHog + /// unless `send_feature_flag_events` is disabled in ClientOptions. + pub async fn get_feature_flag, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result, Error> { + self.get_feature_flag_with_options( + key, + distinct_id, + groups, + person_properties, + group_properties, + true, // send_feature_flag_events + ) + .await + } + + /// Check if a feature flag is enabled for a user + /// + /// This method will automatically capture a `$feature_flag_called` event + /// unless `send_feature_flag_events` is disabled in ClientOptions. + pub async fn is_feature_enabled, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result { + let flag_value = self + .get_feature_flag( + key, + distinct_id, + groups, + person_properties, + group_properties, + ) + .await?; + Ok(match flag_value { + Some(FlagValue::Boolean(b)) => b, + Some(FlagValue::String(_)) => true, // Variants are considered enabled + None => false, + }) + } + + /// Get a feature flag payload for a user + #[allow(dead_code)] + pub(crate) async fn get_feature_flag_payload, D: Into>( + &self, + key: K, + distinct_id: D, + ) -> Result, Error> { + let key_str = key.into(); + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let mut payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id.into(), + }); + + // Add geoip disable parameter if configured + if self.options.disable_geoip { + payload["disable_geoip"] = json!(true); + } + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .timeout(Duration::from_secs( + self.options.feature_flags_request_timeout_seconds, + )) + .send() + .await + .map_err(TransportError::from)?; + + if !response.status().is_success() { + return Ok(None); + } + + let flags_response: FeatureFlagsResponse = response.json().await.map_err(|e| { + ValidationError::SerializationFailed(format!("Failed to parse response: {e}")) + })?; + + Ok(flags_response.get_flag_payload(&key_str)) + } + + /// Get all feature flags as a HashMap + pub async fn get_all_flags( + &self, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result, Error> { + let response = self + .get_feature_flags(distinct_id, groups, person_properties, group_properties) + .await?; + Ok(response.to_flag_values()) + } + + /// Get all feature flags and payloads + pub async fn get_all_flags_and_payloads( + &self, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result< + ( + HashMap, + HashMap, + ), + Error, + > { + let response = self + .get_feature_flags(distinct_id, groups, person_properties, group_properties) + .await?; + Ok((response.to_flag_values(), response.to_flag_payloads())) + } + + /// Evaluate a feature flag locally (requires feature flags to be loaded) + pub fn evaluate_feature_flag_locally( + &self, + flag: &FeatureFlag, + distinct_id: &str, + person_properties: &HashMap, + ) -> Result { + match_feature_flag(flag, distinct_id, person_properties) + .map_err(|e| ValidationError::SerializationFailed(e.message).into()) } } diff --git a/src/client/blocking.rs b/src/client/blocking.rs index 0f9af91..6399fb4 100644 --- a/src/client/blocking.rs +++ b/src/client/blocking.rs @@ -1,15 +1,33 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, RwLock}; use std::time::Duration; use reqwest::{blocking::Client as HttpClient, header::CONTENT_TYPE}; +use serde_json::json; +use crate::endpoints::Endpoint; +use crate::error::{TransportError, ValidationError}; +use crate::feature_flags::{ + match_feature_flag, FeatureFlag, FeatureFlagsResponse, FlagDetail, FlagValue, +}; +use crate::local_evaluation::{FlagCache, FlagPoller, LocalEvaluationConfig, LocalEvaluator}; use crate::{event::InnerEvent, Error, Event}; use super::ClientOptions; +/// Maximum number of distinct IDs to track for feature flag deduplication +const MAX_DICT_SIZE: usize = 50_000; + /// A [`Client`] facilitates interactions with the PostHog API over HTTP. pub struct Client { options: ClientOptions, client: HttpClient, + local_evaluator: Option, + #[allow(dead_code)] + flag_poller: Option, + /// Tracks which feature flags have been called for deduplication. + /// Maps distinct_id -> set of feature flag keys that have been reported. + distinct_ids_feature_flags_reported: Arc>>>, } /// This function constructs a new client using the options provided. @@ -19,23 +37,68 @@ pub fn client>(options: C) -> Client { .timeout(Duration::from_secs(options.request_timeout_seconds)) .build() .unwrap(); // Unwrap here is as safe as `HttpClient::new` - Client { options, client } + + let (local_evaluator, flag_poller) = if options.enable_local_evaluation { + // Safe to unwrap: validation in ClientOptions::build() ensures personal_api_key + // is always Some when enable_local_evaluation is true + let personal_key = options + .personal_api_key + .as_ref() + .expect("personal_api_key must be present when enable_local_evaluation is true"); + + let cache = FlagCache::new(); + + let config = LocalEvaluationConfig { + personal_api_key: personal_key.clone(), + project_api_key: options.api_key.clone(), + api_host: options.endpoints().api_host(), + poll_interval: Duration::from_secs(options.poll_interval_seconds), + request_timeout: Duration::from_secs(options.request_timeout_seconds), + }; + + let poller = FlagPoller::new(config, cache.clone()); + poller.start(); + + (Some(LocalEvaluator::new(cache)), Some(poller)) + } else { + (None, None) + }; + + Client { + options, + client, + local_evaluator, + flag_poller, + distinct_ids_feature_flags_reported: Arc::new(RwLock::new(HashMap::new())), + } } impl Client { /// Capture the provided event, sending it to PostHog. pub fn capture(&self, event: Event) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + + // Note: Infinite loop prevention for $feature_flag_called events + // If we ever implement automatic feature flag evaluation for all events + // (auto-adding of $feature/* properties), we must skip flag evaluation + // when event.event_name() == "$feature_flag_called" to prevent infinite + // loops. Current implementation doesn't auto-evaluate flags in capture(), + // so no loop risk exists today. + let inner_event = InnerEvent::new(event, self.options.api_key.clone()); - let payload = - serde_json::to_string(&inner_event).map_err(|e| Error::Serialization(e.to_string()))?; + let payload = serde_json::to_string(&inner_event) + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; + let url = self.options.endpoints().build_url(Endpoint::Capture); self.client - .post(&self.options.api_endpoint) + .post(&url) .header(CONTENT_TYPE, "application/json") .body(payload) .send() - .map_err(|e| Error::Connection(e.to_string()))?; + .map_err(TransportError::from)?; Ok(()) } @@ -43,21 +106,391 @@ impl Client { /// Capture a collection of events with a single request. This function may be /// more performant than capturing a list of events individually. pub fn capture_batch(&self, events: Vec) -> Result<(), Error> { + if self.options.is_disabled() { + return Ok(()); + } + let events: Vec<_> = events .into_iter() - .map(|event| InnerEvent::new(event, self.options.api_key.clone())) + .map(|event| InnerEvent::new(event, self.options.api_key.to_string())) .collect(); - let payload = - serde_json::to_string(&events).map_err(|e| Error::Serialization(e.to_string()))?; + let payload = serde_json::to_string(&events) + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; + let url = self.options.endpoints().batch_event_endpoint(); self.client - .post(&self.options.api_endpoint) + .post(&url) .header(CONTENT_TYPE, "application/json") .body(payload) .send() - .map_err(|e| Error::Connection(e.to_string()))?; + .map_err(TransportError::from)?; + + Ok(()) + } + + /// Get all feature flags for a user + pub fn get_feature_flags( + &self, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result { + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let mut payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id, + }); + + if let Some(groups) = groups { + payload["groups"] = json!(groups); + } + + if let Some(person_properties) = person_properties { + payload["person_properties"] = json!(person_properties); + } + + if let Some(group_properties) = group_properties { + payload["group_properties"] = json!(group_properties); + } + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .send() + .map_err(TransportError::from)?; + + if !response.status().is_success() { + let status = response.status(); + let text = response + .text() + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(TransportError::HttpError( + status.as_u16(), + format!("API request failed with status {}: {}", status, text), + ) + .into()); + } + + let flags_response = response.json::().map_err(|e| { + ValidationError::SerializationFailed(format!( + "Failed to parse feature flags response: {}", + e + )) + })?; + + Ok(flags_response) + } + + /// Internal method to capture $feature_flag_called events. + /// + /// Handles deduplication and event construction inline. + #[allow(clippy::too_many_arguments)] + fn capture_feature_flag_called( + &self, + distinct_id: &str, + flag_key: &str, + flag_response: Option<&FlagValue>, + payload: Option, + locally_evaluated: bool, + groups: Option>, + disable_geoip: Option, + request_id: Option, + flag_details: Option<&FlagDetail>, + ) -> Result<(), Error> { + // Create the reported key for deduplication + // Format: "{key}_::null::" for None, "{key}_{value}" otherwise + let feature_flag_reported_key = match flag_response { + None => format!("{}_::null::", flag_key), + Some(flag_value) => format!("{}_{}", flag_key, flag_value), + }; + + // Check if already reported for deduplication + { + let reported = self.distinct_ids_feature_flags_reported.read().unwrap(); + if let Some(flags) = reported.get(distinct_id) { + if flags.contains(&feature_flag_reported_key) { + // Already reported, skip + return Ok(()); + } + } + } + + // Build event properties inline + let mut event = Event::new("$feature_flag_called", distinct_id); + + // Add required properties + event.insert_prop("$feature_flag", flag_key)?; + event.insert_prop("$feature_flag_response", &flag_response)?; + event.insert_prop("locally_evaluated", locally_evaluated)?; + + // Add $feature/{key} property + event.insert_prop(format!("$feature/{}", flag_key), &flag_response)?; + + // Add optional properties + if let Some(p) = payload { + event.insert_prop("$feature_flag_payload", p)?; + } + + // Add request_id if provided + if let Some(req_id) = request_id { + event.insert_prop("$feature_flag_request_id", req_id)?; + } + + // Add flag_details metadata if provided + if let Some(details) = flag_details { + // Add reason + if let Some(reason) = &details.reason { + if let Some(desc) = &reason.description { + event.insert_prop("$feature_flag_reason", desc.clone())?; + } + } + + // Add metadata (version and id) + if let Some(metadata) = &details.metadata { + event.insert_prop("$feature_flag_version", metadata.version)?; + event.insert_prop("$feature_flag_id", metadata.id)?; + } + } + + // Add groups if present + if let Some(g) = groups { + for (group_name, group_id) in g { + event.add_group(&group_name, &group_id); + } + } + + // Add disable_geoip if provided + if let Some(disable_geo) = disable_geoip { + if disable_geo { + event.insert_prop("$geoip_disable", true)?; + } + } + + // Capture the event + self.capture(event)?; + + // Mark as reported (even if capture failed to avoid retry storms) + { + let mut reported = self.distinct_ids_feature_flags_reported.write().unwrap(); + + // Check size limit and evict if necessary + if reported.len() >= MAX_DICT_SIZE && !reported.contains_key(distinct_id) { + // Remove first entry (FIFO eviction) + if let Some(first_key) = reported.keys().next().cloned() { + reported.remove(&first_key); + } + } + + reported + .entry(distinct_id.to_string()) + .or_insert_with(HashSet::new) + .insert(feature_flag_reported_key); + } Ok(()) } + + /// Get a specific feature flag value for a user + /// + /// This method will automatically capture a `$feature_flag_called` event + /// unless `send_feature_flag_events` is disabled in ClientOptions. + pub fn get_feature_flag, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result, Error> { + self.get_feature_flag_with_options( + key, + distinct_id, + groups, + person_properties, + group_properties, + true, // send_feature_flag_events + ) + } + + /// Get a specific feature flag value with control over event capture + pub(crate) fn get_feature_flag_with_options, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + send_feature_flag_events: bool, + ) -> Result, Error> { + let key = key.into(); + let distinct_id = distinct_id.into(); + let mut locally_evaluated = false; + let mut flag_value: Option = None; + let mut payload: Option = None; + let mut request_id: Option = None; + let mut flag_details_map: HashMap = HashMap::new(); + + // Try local evaluation first if available + if let Some(ref evaluator) = self.local_evaluator { + let props = person_properties.clone().unwrap_or_default(); + match evaluator.evaluate_flag(&key, &distinct_id, &props) { + Ok(Some(value)) => { + locally_evaluated = true; + flag_value = Some(value); + } + Ok(None) => { + // Flag not found locally, fall through to API + } + Err(_e) => { + // Inconclusive match, fall through to API + } + } + } + + // Fall back to API if not locally evaluated + if flag_value.is_none() { + match self.get_feature_flags( + distinct_id.clone(), + groups.clone(), + person_properties, + group_properties, + ) { + Ok(response) => { + // Use helper methods to get flag value and payload + flag_value = response.get_flag_value(&key); + payload = response.get_flag_payload(&key); + request_id = response.request_id; + flag_details_map = response.flags; + } + Err(_e) => { + // Return None on error (graceful degradation) + flag_value = None; + } + } + } + + // Capture $feature_flag_called event if enabled + if self.options.send_feature_flag_events && send_feature_flag_events { + let _ = self.capture_feature_flag_called( + &distinct_id, + &key, + flag_value.as_ref(), + payload, + locally_evaluated, + groups, + None, // disable_geoip - not yet supported + request_id, + flag_details_map.get(&key), + ); + } + + Ok(flag_value) + } + + /// Check if a feature flag is enabled for a user + /// + /// This method will automatically capture a `$feature_flag_called` event + /// unless `send_feature_flag_events` is disabled in ClientOptions. + pub fn is_feature_enabled, D: Into>( + &self, + key: K, + distinct_id: D, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result { + let flag_value = self.get_feature_flag( + key, + distinct_id, + groups, + person_properties, + group_properties, + )?; + Ok(match flag_value { + Some(FlagValue::Boolean(b)) => b, + Some(FlagValue::String(_)) => true, // Variants are considered enabled + None => false, + }) + } + + /// Get a feature flag payload for a user + pub(crate) fn get_feature_flag_payload, D: Into>( + &self, + key: K, + distinct_id: D, + ) -> Result, Error> { + let key_str = key.into(); + let flags_endpoint = self.options.endpoints().build_url(Endpoint::Flags); + + let payload = json!({ + "api_key": self.options.api_key, + "distinct_id": distinct_id.into(), + }); + + let response = self + .client + .post(&flags_endpoint) + .header(CONTENT_TYPE, "application/json") + .json(&payload) + .send() + .map_err(TransportError::from)?; + + if !response.status().is_success() { + return Ok(None); + } + + let flags_response: FeatureFlagsResponse = response.json().map_err(|e| { + ValidationError::SerializationFailed(format!("Failed to parse response: {}", e)) + })?; + + Ok(flags_response.get_flag_payload(&key_str)) + } + + /// Get all feature flags as a HashMap + pub fn get_all_flags( + &self, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result, Error> { + let response = + self.get_feature_flags(distinct_id, groups, person_properties, group_properties)?; + Ok(response.to_flag_values()) + } + + /// Get all feature flags and payloads + pub fn get_all_flags_and_payloads( + &self, + distinct_id: String, + groups: Option>, + person_properties: Option>, + group_properties: Option>>, + ) -> Result< + ( + HashMap, + HashMap, + ), + Error, + > { + let response = + self.get_feature_flags(distinct_id, groups, person_properties, group_properties)?; + Ok((response.to_flag_values(), response.to_flag_payloads())) + } + + /// Evaluate a feature flag locally (requires feature flags to be loaded) + pub fn evaluate_feature_flag_locally( + &self, + flag: &FeatureFlag, + distinct_id: &str, + person_properties: &HashMap, + ) -> Result { + match_feature_flag(flag, distinct_id, person_properties) + .map_err(|e| ValidationError::SerializationFailed(e.message).into()) + } } diff --git a/src/client/config.rs b/src/client/config.rs new file mode 100644 index 0000000..3f9c699 --- /dev/null +++ b/src/client/config.rs @@ -0,0 +1,420 @@ +use crate::endpoints::{normalize_endpoint, EndpointManager}; +use crate::error::InitializationError; +use crate::Error; + +/// Configuration options for the PostHog client. +#[derive(Debug, Clone)] +pub struct ClientOptions { + pub(crate) api_key: String, + pub(crate) request_timeout_seconds: u64, + + // Feature flags related fields + pub(crate) personal_api_key: Option, + pub(crate) enable_local_evaluation: bool, + pub(crate) poll_interval_seconds: u64, + pub(crate) feature_flags_request_timeout_seconds: u64, + pub(crate) send_feature_flag_events: bool, + + // Other configuration + pub(crate) gzip: bool, + pub(crate) disabled: bool, + pub(crate) disable_geoip: bool, + + // Endpoint management + pub(crate) endpoint_manager: EndpointManager, +} + +impl ClientOptions { + /// Get the full endpoint URL for single event capture + #[cfg(test)] + pub(crate) fn single_event_endpoint(&self) -> String { + use crate::endpoints::Endpoint; + self.endpoint_manager.build_url(Endpoint::Capture) + } + + /// Get the full endpoint URL for batch event capture + #[cfg(test)] + pub(crate) fn batch_event_endpoint(&self) -> String { + self.endpoint_manager.batch_event_endpoint() + } + + /// Get the endpoint manager + pub(crate) fn endpoints(&self) -> &EndpointManager { + &self.endpoint_manager + } + + /// Check if the client is disabled + pub(crate) fn is_disabled(&self) -> bool { + self.disabled + } +} + +/// Builder for ClientOptions with validation. +pub struct ClientOptionsBuilder { + api_endpoint: Option, + host: Option, + api_key: Option, + request_timeout_seconds: Option, + personal_api_key: Option, + enable_local_evaluation: Option, + poll_interval_seconds: Option, + feature_flags_request_timeout_seconds: Option, + send_feature_flag_events: Option, + gzip: Option, + disabled: Option, + disable_geoip: Option, +} + +impl ClientOptionsBuilder { + /// Create a new ClientOptionsBuilder with default values + pub fn new() -> Self { + Self { + api_endpoint: None, + host: None, + api_key: None, + request_timeout_seconds: None, + personal_api_key: None, + enable_local_evaluation: None, + poll_interval_seconds: None, + feature_flags_request_timeout_seconds: None, + send_feature_flag_events: None, + gzip: None, + disabled: None, + disable_geoip: None, + } + } + + /// Set the API key (required) + pub fn api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Set the host URL (defaults to US ingestion endpoint) + pub fn host(mut self, host: impl Into) -> Self { + self.host = Some(host.into()); + self + } + + /// Set the API endpoint. Accepts either: + /// - A hostname like "https://us.posthog.com" + /// - A full endpoint URL like "https://us.i.posthog.com/i/v0/e/" (for backward compatibility) + /// + /// The SDK will automatically append the appropriate paths (/i/v0/e/ or /batch/) + /// based on the operation being performed. + pub fn api_endpoint(mut self, endpoint: impl Into) -> Self { + self.api_endpoint = Some(endpoint.into()); + self + } + + /// Set the request timeout in seconds (default: 30) + pub fn request_timeout_seconds(mut self, seconds: u64) -> Self { + self.request_timeout_seconds = Some(seconds); + self + } + + /// Set the personal API key for flag definitions (required for local evaluation) + pub fn personal_api_key(mut self, key: impl Into) -> Self { + self.personal_api_key = Some(key.into()); + self + } + + /// Enable local evaluation of feature flags + pub fn enable_local_evaluation(mut self, enable: bool) -> Self { + self.enable_local_evaluation = Some(enable); + self + } + + /// Set the poll interval for flag definitions (default: 30) + pub fn poll_interval_seconds(mut self, seconds: u64) -> Self { + self.poll_interval_seconds = Some(seconds); + self + } + + /// Set the feature flags request timeout (default: 3) + pub fn feature_flags_request_timeout_seconds(mut self, seconds: u64) -> Self { + self.feature_flags_request_timeout_seconds = Some(seconds); + self + } + + /// Enable automatic $feature_flag_called events (default: true) + pub fn send_feature_flag_events(mut self, send: bool) -> Self { + self.send_feature_flag_events = Some(send); + self + } + + /// Enable gzip compression for requests + pub fn gzip(mut self, enable: bool) -> Self { + self.gzip = Some(enable); + self + } + + /// Disable tracking (useful for development) + pub fn disabled(mut self, disable: bool) -> Self { + self.disabled = Some(disable); + self + } + + /// Disable automatic geoip enrichment + pub fn disable_geoip(mut self, disable: bool) -> Self { + self.disable_geoip = Some(disable); + self + } + + /// Build the ClientOptions, validating all fields + pub fn build(self) -> Result { + let api_key = self.api_key.ok_or(InitializationError::MissingApiKey)?; + + // Validate that personal_api_key is provided when local evaluation is enabled + let enable_local_evaluation = self.enable_local_evaluation.unwrap_or(false); + if enable_local_evaluation && self.personal_api_key.is_none() { + return Err(InitializationError::MissingPersonalApiKey.into()); + } + + let request_timeout_seconds = self.request_timeout_seconds.unwrap_or(30); + + // Process the endpoint with correct priority: api_endpoint > host + let endpoint_to_use = self.api_endpoint.or(self.host.clone()); + + // Validate the endpoint if provided + if let Some(ref endpoint) = endpoint_to_use { + normalize_endpoint(endpoint)?; + } + + // Initialize endpoint manager with the prioritized endpoint + let endpoint_manager = EndpointManager::new(endpoint_to_use); + + Ok(ClientOptions { + api_key, + request_timeout_seconds, + personal_api_key: self.personal_api_key, + enable_local_evaluation, + poll_interval_seconds: self.poll_interval_seconds.unwrap_or(30), + feature_flags_request_timeout_seconds: self + .feature_flags_request_timeout_seconds + .unwrap_or(3), + send_feature_flag_events: self.send_feature_flag_events.unwrap_or(true), + gzip: self.gzip.unwrap_or(false), + disabled: self.disabled.unwrap_or(false), + disable_geoip: self.disable_geoip.unwrap_or(false), + endpoint_manager, + }) + } +} + +impl Default for ClientOptionsBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From<&str> for ClientOptions { + fn from(api_key: &str) -> Self { + ClientOptionsBuilder::default() + .api_key(api_key.to_string()) + .build() + .expect("We always set the API key, so this is infallible") + } +} + +impl From<(&str, &str)> for ClientOptions { + /// Create options from API key and host + fn from((api_key, host): (&str, &str)) -> Self { + ClientOptionsBuilder::default() + .api_key(api_key.to_string()) + .host(host.to_string()) + .build() + .expect("We always set the API key, so this is infallible") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_options_builder_default_endpoint() { + let options = ClientOptionsBuilder::new() + .api_key("test_key") + .build() + .unwrap(); + + assert_eq!( + options.single_event_endpoint(), + "https://us.i.posthog.com/i/v0/e/" + ); + assert_eq!( + options.batch_event_endpoint(), + "https://us.i.posthog.com/batch/" + ); + } + + #[test] + fn test_client_options_builder_with_hostname() { + let options = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("https://eu.posthog.com") + .build() + .unwrap(); + + // EU PostHog Cloud redirects to EU ingestion endpoint + assert_eq!( + options.single_event_endpoint(), + "https://eu.i.posthog.com/i/v0/e/" + ); + assert_eq!( + options.batch_event_endpoint(), + "https://eu.i.posthog.com/batch/" + ); + } + + #[test] + fn test_client_options_builder_with_full_endpoint_single() { + // Backward compatibility: accept full endpoint and strip path + let options = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("https://us.i.posthog.com/i/v0/e/") + .build() + .unwrap(); + + assert_eq!( + options.single_event_endpoint(), + "https://us.i.posthog.com/i/v0/e/" + ); + assert_eq!( + options.batch_event_endpoint(), + "https://us.i.posthog.com/batch/" + ); + } + + #[test] + fn test_client_options_builder_with_full_endpoint_batch() { + // Backward compatibility: accept batch endpoint and strip path + let options = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("https://us.i.posthog.com/batch/") + .build() + .unwrap(); + + assert_eq!( + options.single_event_endpoint(), + "https://us.i.posthog.com/i/v0/e/" + ); + assert_eq!( + options.batch_event_endpoint(), + "https://us.i.posthog.com/batch/" + ); + } + + #[test] + fn test_client_options_builder_with_port() { + let options = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("http://localhost:8000") + .build() + .unwrap(); + + assert_eq!( + options.single_event_endpoint(), + "http://localhost:8000/i/v0/e/" + ); + assert_eq!( + options.batch_event_endpoint(), + "http://localhost:8000/batch/" + ); + } + + #[test] + fn test_client_options_builder_with_trailing_slash() { + let options = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("https://eu.posthog.com/") + .build() + .unwrap(); + + assert_eq!( + options.single_event_endpoint(), + "https://eu.i.posthog.com/i/v0/e/" + ); + assert_eq!( + options.batch_event_endpoint(), + "https://eu.i.posthog.com/batch/" + ); + } + + #[test] + fn test_client_options_builder_invalid_endpoint_no_scheme() { + let result = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("posthog.com") + .build(); + + assert!(result.is_err()); + match result.unwrap_err() { + Error::Initialization(InitializationError::InvalidEndpoint(msg)) => { + assert!(msg.contains("Endpoint must start with http://")); + } + _ => panic!("Expected InvalidEndpoint error"), + } + } + + #[test] + fn test_client_options_builder_invalid_endpoint_malformed() { + let result = ClientOptionsBuilder::new() + .api_key("test_key") + .api_endpoint("not a url") + .build(); + + assert!(result.is_err()); + match result.unwrap_err() { + Error::Initialization(InitializationError::InvalidEndpoint(msg)) => { + // Should contain error about scheme or being invalid + assert!(msg.contains("http://") || msg.contains("https://")); + } + _ => panic!("Expected InvalidEndpoint error"), + } + } + + #[test] + fn test_client_options_builder_missing_api_key() { + let result = ClientOptionsBuilder::new().build(); + + assert!(result.is_err()); + match result.unwrap_err() { + Error::Initialization(InitializationError::MissingApiKey) => { + // Correct error type + } + _ => panic!("Expected MissingApiKey error"), + } + } + + #[test] + fn test_client_options_builder_local_evaluation_without_personal_key() { + let result = ClientOptionsBuilder::new() + .api_key("test_key") + .enable_local_evaluation(true) + .build(); + + assert!(result.is_err()); + match result.unwrap_err() { + Error::Initialization(InitializationError::MissingPersonalApiKey) => { + // Correct error type + } + _ => panic!("Expected MissingPersonalApiKey error"), + } + } + + #[test] + fn test_client_options_builder_local_evaluation_with_personal_key() { + let result = ClientOptionsBuilder::new() + .api_key("test_key") + .personal_api_key("personal_key") + .enable_local_evaluation(true) + .build(); + + assert!(result.is_ok()); + let options = result.unwrap(); + assert!(options.enable_local_evaluation); + assert_eq!(options.personal_api_key, Some("personal_key".to_string())); + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index a60d772..092a047 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,5 +1,4 @@ -use crate::API_ENDPOINT; -use derive_builder::Builder; +mod config; #[cfg(not(feature = "async-client"))] mod blocking; @@ -15,21 +14,5 @@ pub use async_client::client; #[cfg(feature = "async-client")] pub use async_client::Client; -#[derive(Builder)] -pub struct ClientOptions { - #[builder(default = "API_ENDPOINT.to_string()")] - api_endpoint: String, - api_key: String, - - #[builder(default = "30")] - request_timeout_seconds: u64, -} - -impl From<&str> for ClientOptions { - fn from(api_key: &str) -> Self { - ClientOptionsBuilder::default() - .api_key(api_key.to_string()) - .build() - .expect("We always set the API key, so this is infallible") - } -} +// Re-export configuration types +pub use config::{ClientOptions, ClientOptionsBuilder}; diff --git a/src/endpoints.rs b/src/endpoints.rs new file mode 100644 index 0000000..15484cf --- /dev/null +++ b/src/endpoints.rs @@ -0,0 +1,205 @@ +use crate::error::InitializationError; +use crate::Error; +use std::fmt; + +/// US ingestion endpoint +pub const US_INGESTION_ENDPOINT: &str = "https://us.i.posthog.com"; + +/// EU ingestion endpoint +pub const EU_INGESTION_ENDPOINT: &str = "https://eu.i.posthog.com"; + +/// Default host (US by default) +pub const DEFAULT_HOST: &str = US_INGESTION_ENDPOINT; + +/// API endpoints for different operations +#[derive(Debug, Clone)] +pub(crate) enum Endpoint { + /// Event capture endpoint + Capture, + /// Feature flags endpoint + Flags, + /// Local evaluation endpoint + LocalEvaluation, +} + +impl Endpoint { + /// Get the path for this endpoint + pub fn path(&self) -> &str { + match self { + Endpoint::Capture => "/i/v0/e/", + Endpoint::Flags => "/flags/?v=2", + Endpoint::LocalEvaluation => "/api/feature_flag/local_evaluation/", + } + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.path()) + } +} + +/// Normalize an endpoint to a base URL. +/// Accepts both hostnames (https://us.posthog.com) and full endpoints (https://us.i.posthog.com/i/v0/e/) +pub(crate) fn normalize_endpoint(endpoint: &str) -> Result { + let endpoint = endpoint.trim(); + + // Basic validation - must start with http:// or https:// + if !endpoint.starts_with("http://") && !endpoint.starts_with("https://") { + return Err(InitializationError::InvalidEndpoint( + "Endpoint must start with http:// or https://".to_string(), + ) + .into()); + } + + // Parse as URL to validate + let url = endpoint + .parse::() + .map_err(|e| InitializationError::InvalidEndpoint(format!("Invalid URL: {}", e)))?; + + // Extract scheme and host + let scheme = url.scheme(); + let host = url + .host_str() + .ok_or_else(|| InitializationError::InvalidEndpoint("Missing host".to_string()))?; + + // Check if this looks like a full endpoint path (contains /i/v0/e or /batch) + let path = url.path(); + if path.contains("/i/v0/e") || path.contains("/batch") { + // Strip the path, keep only scheme://host:port + let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default(); + Ok(format!("{}://{}{}", scheme, host, port)) + } else { + // Already a base URL, just reconstruct it cleanly + let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default(); + Ok(format!("{}://{}{}", scheme, host, port)) + } +} + +/// Manages PostHog API endpoints and host configuration +#[derive(Debug, Clone)] +pub(crate) struct EndpointManager { + base_host: String, +} + +impl EndpointManager { + /// Create a new endpoint manager with the given host + pub fn new(host: Option) -> Self { + // Normalize the host if provided (strips paths from full endpoint URLs) + let normalized_host = host.and_then(|h| normalize_endpoint(&h).ok()); + + let base_host = Self::determine_server_host(normalized_host); + + Self { base_host } + } + + /// Determine the actual server host based on the provided host + pub fn determine_server_host(host: Option) -> String { + let host_or_default = host.unwrap_or_else(|| DEFAULT_HOST.to_string()); + let trimmed_host = host_or_default.trim_end_matches('/'); + + match trimmed_host { + "https://app.posthog.com" | "https://us.posthog.com" => { + US_INGESTION_ENDPOINT.to_string() + } + "https://eu.posthog.com" => EU_INGESTION_ENDPOINT.to_string(), + _ => host_or_default, + } + } + + /// Build a full URL for a given endpoint + pub fn build_url(&self, endpoint: Endpoint) -> String { + format!( + "{}{}", + self.base_host.trim_end_matches('/'), + endpoint.path() + ) + } + + /// Build a URL with a custom path + pub fn build_custom_url(&self, path: &str) -> String { + let normalized_path = if path.starts_with('/') { + path.to_string() + } else { + format!("/{path}") + }; + format!( + "{}{}", + self.base_host.trim_end_matches('/'), + normalized_path + ) + } + + /// Get the base host for API operations (without the path) + pub fn api_host(&self) -> String { + self.base_host.trim_end_matches('/').to_string() + } + + /// Get the batch event capture endpoint URL (legacy, uses same endpoint as single) + pub fn batch_event_endpoint(&self) -> String { + self.build_custom_url("/batch/") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_determine_server_host() { + assert_eq!( + EndpointManager::determine_server_host(None), + US_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://app.posthog.com".to_string())), + US_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://us.posthog.com".to_string())), + US_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://eu.posthog.com".to_string())), + EU_INGESTION_ENDPOINT + ); + + assert_eq!( + EndpointManager::determine_server_host(Some("https://custom.domain.com".to_string())), + "https://custom.domain.com" + ); + } + + #[test] + fn test_build_url() { + let manager = EndpointManager::new(None); + + assert_eq!( + manager.build_url(Endpoint::Capture), + format!("{}/i/v0/e/", US_INGESTION_ENDPOINT) + ); + + assert_eq!( + manager.build_url(Endpoint::Flags), + format!("{}/flags/?v=2", US_INGESTION_ENDPOINT) + ); + } + + #[test] + fn test_build_custom_url() { + let manager = EndpointManager::new(Some("https://custom.com/".to_string())); + + assert_eq!( + manager.build_custom_url("/api/test"), + "https://custom.com/api/test" + ); + + assert_eq!( + manager.build_custom_url("api/test"), + "https://custom.com/api/test" + ); + } +} diff --git a/src/error.rs b/src/error.rs index e193a85..b4e61af 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,23 +1,207 @@ -use std::fmt::{Display, Formatter}; +use std::time::Duration; +use thiserror::Error; -impl Display for Error { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +/// Main error type for the PostHog Rust SDK. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum Error { + // Deprecated variants - kept for backward compatibility + #[deprecated(since = "0.4.0", note = "Use Error::Connection instead")] + #[error("Connection error: {0}")] + Connection(String), + + #[deprecated(since = "0.4.0", note = "Use Error::Validation instead")] + #[error("Serialization error: {0}")] + Serialization(String), + + #[deprecated( + since = "0.4.0", + note = "Use Error::Initialization(InitializationError::AlreadyInitialized) instead" + )] + #[error("Global client already initialized")] + AlreadyInitialized, + + #[deprecated( + since = "0.4.0", + note = "Use Error::Initialization(InitializationError::NotInitialized) instead" + )] + #[error("Global client not initialized")] + NotInitialized, + + #[deprecated(since = "0.4.0", note = "Use Error::Validation instead")] + #[error("Invalid timestamp: {0}")] + InvalidTimestamp(String), + + #[deprecated(since = "0.4.0", note = "Use Error::Initialization instead")] + #[error("Uninitialized field: {0}")] + UninitializedField(&'static str), + + #[deprecated(since = "0.4.0", note = "Use Error::Validation instead")] + #[error("Validation error: {0}")] + ValidationError(String), + + // New error categories + /// Transport-layer errors (network, HTTP, etc.) + #[error(transparent)] + Transport(#[from] TransportError), + + /// Validation errors for events and data + #[error(transparent)] + Validation(#[from] ValidationError), + + /// Initialization and configuration errors + #[error(transparent)] + Initialization(#[from] InitializationError), +} + +impl Error { + /// Returns true if this error can be retried. + pub fn is_retryable(&self) -> bool { + match self { + Error::Transport(e) => e.is_retryable(), + _ => false, + } + } + + /// Returns true if this error is due to invalid usage (4xx, validation, or config errors). + pub fn is_client_error(&self) -> bool { match self { - Error::Connection(msg) => write!(f, "Connection Error: {}", msg), - Error::Serialization(msg) => write!(f, "Serialization Error: {}", msg), - Error::AlreadyInitialized => write!(f, "Client already initialized"), - Error::NotInitialized => write!(f, "Client not initialized"), - Error::InvalidTimestamp(msg) => write!(f, "Invalid Timestamp: {}", msg), + Error::Validation(_) | Error::Initialization(_) => true, + Error::Transport(e) => e.is_client_error(), + _ => false, } } } -#[derive(Debug)] +/// Network transport and HTTP errors. +#[derive(Debug, Error)] #[non_exhaustive] -pub enum Error { - Connection(String), - Serialization(String), +pub enum TransportError { + /// The request timed out after the specified duration + #[error("Request timed out after {0:?}")] + Timeout(Duration), + + /// DNS resolution failed for the hostname + #[error("DNS resolution failed: {0}")] + DnsResolution(String), + + /// Network is unreachable + #[error("Network is unreachable")] + NetworkUnreachable, + + /// HTTP error with status code and message + #[error("HTTP error {0}: {1}")] + HttpError(u16, String), + + /// TLS/SSL error + #[error("TLS error: {0}")] + TlsError(String), +} + +impl TransportError { + /// Returns true if this error can be retried (timeouts, 5xx, 429). + pub fn is_retryable(&self) -> bool { + match self { + TransportError::Timeout(_) => true, + TransportError::NetworkUnreachable => true, + TransportError::HttpError(status, _) => { + // Retry on 5xx errors and 429 (rate limit) + (*status >= 500 && *status < 600) || *status == 429 + } + _ => false, + } + } + + fn is_client_error(&self) -> bool { + matches!(self, TransportError::HttpError(400..=499, _)) + } +} + +impl From for TransportError { + fn from(err: reqwest::Error) -> Self { + if err.is_timeout() { + return TransportError::Timeout(Duration::from_secs(30)); + } + + if err.is_connect() { + return err + .url() + .and_then(|u| u.host_str()) + .map(|host| TransportError::DnsResolution(host.to_string())) + .unwrap_or(TransportError::NetworkUnreachable); + } + + if let Some(status) = err.status() { + return TransportError::HttpError(status.as_u16(), err.to_string()); + } + + let err_str = err.to_string(); + if err_str.contains("tls") || err_str.contains("ssl") { + TransportError::TlsError(err_str) + } else { + TransportError::NetworkUnreachable + } + } +} + +/// Event validation and data integrity errors. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum ValidationError { + /// Event property value is too large + #[error("Property '{key}' is too large ({size} bytes)")] + PropertyTooLarge { key: String, size: usize }, + + /// Event property has invalid type + #[error("Property '{key}' has invalid type (expected {expected})")] + InvalidPropertyType { key: String, expected: String }, + + /// Timestamp is invalid (e.g., in the future) + #[error("Invalid timestamp: {0}")] + InvalidTimestamp(String), + + /// Distinct ID is invalid or empty + #[error("Invalid distinct_id: {0}")] + InvalidDistinctId(String), + + /// Batch size exceeds maximum allowed + #[error("Batch size {size} exceeds maximum {max}")] + BatchSizeExceeded { size: usize, max: usize }, + + /// Event name is too long + #[error("Event name is too long ({length} chars, max {max})")] + EventNameTooLong { length: usize, max: usize }, + + /// JSON serialization failed (should rarely happen if validation is correct) + #[error("Serialization failed: {0}")] + SerializationFailed(String), +} + +/// Client initialization and configuration errors. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum InitializationError { + /// API key is missing or empty + #[error("API key is missing or empty")] + MissingApiKey, + + /// API endpoint URL is invalid + #[error("Invalid endpoint: {0}")] + InvalidEndpoint(String), + + /// Timeout value is invalid + #[error("Invalid timeout: {0:?}")] + InvalidTimeout(Duration), + + /// Global client is already initialized + #[error("Global client is already initialized")] AlreadyInitialized, + + /// Global client is not initialized + #[error("Global client is not initialized")] NotInitialized, - InvalidTimestamp(String), + + /// Personal API key is required when local evaluation is enabled + #[error("Personal API key is required when enable_local_evaluation is true")] + MissingPersonalApiKey, } diff --git a/src/event.rs b/src/event.rs index 9b406b9..9aa8651 100644 --- a/src/event.rs +++ b/src/event.rs @@ -5,6 +5,7 @@ use semver::Version; use serde::Serialize; use uuid::Uuid; +use crate::error::ValidationError; use crate::Error; /// An [`Event`] represents an interaction a user has with your app or @@ -23,7 +24,7 @@ pub struct Event { impl Event { /// Capture a new identified [`Event`]. Unless you have a distinct ID you can - /// associate with a user, you probably want to use [`new_anon`] instead. + /// associate with a user, you probably want to use [`Event::new_anon`] instead. pub fn new>(event: S, distinct_id: S) -> Self { Self { event: event.into(), @@ -35,7 +36,7 @@ impl Event { } /// Capture a new anonymous event. - /// See https://posthog.com/docs/data/anonymous-vs-identified-events#how-to-capture-anonymous-events + /// See pub fn new_anon>(event: S) -> Self { let mut res = Self { event: event.into(), @@ -57,13 +58,13 @@ impl Event { key: K, prop: P, ) -> Result<(), Error> { - let as_json = - serde_json::to_value(prop).map_err(|e| Error::Serialization(e.to_string()))?; + let as_json = serde_json::to_value(prop) + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; let _ = self.properties.insert(key.into(), as_json); Ok(()) } - /// Capture this as a group event. See https://posthog.com/docs/product-analytics/group-analytics#how-to-capture-group-events + /// Capture this as a group event. See /// Note that group events cannot be personless, and will be automatically upgraded to include person profile processing if /// they were anonymous. This might lead to "empty" person profiles being created. pub fn add_group(&mut self, group_name: &str, group_id: &str) { @@ -81,9 +82,10 @@ impl Event { Tz: TimeZone, { if timestamp > Utc::now() + Duration::seconds(1) { - return Err(Error::InvalidTimestamp(String::from( + return Err(ValidationError::InvalidTimestamp(String::from( "Events cannot occur in the future", - ))); + )) + .into()); } self.timestamp = Some(timestamp.naive_utc()); Ok(()) diff --git a/src/feature_flags.rs b/src/feature_flags.rs new file mode 100644 index 0000000..bd03656 --- /dev/null +++ b/src/feature_flags.rs @@ -0,0 +1,697 @@ +use serde::{Deserialize, Serialize}; +use sha1::{Digest, Sha1}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum FlagValue { + Boolean(bool), + String(String), +} + +#[derive(Debug)] +pub(crate) struct InconclusiveMatchError { + pub message: String, +} + +impl InconclusiveMatchError { + pub fn new(message: &str) -> Self { + Self { + message: message.to_string(), + } + } +} + +impl Default for FlagValue { + fn default() -> Self { + FlagValue::Boolean(false) + } +} + +impl std::fmt::Display for FlagValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FlagValue::Boolean(b) => write!(f, "{}", b), + FlagValue::String(s) => write!(f, "{}", s), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FeatureFlag { + pub key: String, + pub active: bool, + #[serde(default)] + pub(crate) filters: FeatureFlagFilters, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub(crate) struct FeatureFlagFilters { + #[serde(default)] + pub groups: Vec, + #[serde(default)] + pub multivariate: Option, + #[serde(default)] + pub payloads: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct FeatureFlagCondition { + #[serde(default)] + pub properties: Vec, + pub rollout_percentage: Option, + pub variant: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Property { + pub key: String, + pub value: serde_json::Value, + #[serde(default = "default_operator")] + pub operator: String, + #[serde(rename = "type")] + pub property_type: Option, +} + +fn default_operator() -> String { + "exact".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub(crate) struct MultivariateFilter { + pub variants: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct MultivariateVariant { + pub key: String, + pub rollout_percentage: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FeatureFlagsResponse { + pub flags: HashMap, + #[serde(rename = "errorsWhileComputingFlags")] + #[serde(default)] + pub errors_while_computing_flags: bool, + #[serde(rename = "requestId")] + #[serde(default)] + pub request_id: Option, +} + +impl FeatureFlagsResponse { + /// Get the flag value for a specific flag key + pub fn get_flag_value(&self, key: &str) -> Option { + self.flags.get(key).map(|d| d.to_flag_value()) + } + + /// Get the payload for a specific flag key + pub fn get_flag_payload(&self, key: &str) -> Option { + self.flags.get(key).and_then(|d| d.payload()) + } + + /// Get all flag values as a HashMap + pub fn to_flag_values(&self) -> HashMap { + self.flags + .iter() + .map(|(key, detail)| (key.clone(), detail.to_flag_value())) + .collect() + } + + /// Get all flag payloads as a HashMap + pub fn to_flag_payloads(&self) -> HashMap { + self.flags + .iter() + .filter_map(|(key, detail)| detail.payload().map(|p| (key.clone(), p))) + .collect() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlagDetail { + pub key: String, + pub enabled: bool, + pub variant: Option, + #[serde(default)] + pub reason: Option, + #[serde(default)] + pub metadata: Option, +} + +impl FlagDetail { + /// Convert FlagDetail to FlagValue + pub fn to_flag_value(&self) -> FlagValue { + if self.enabled { + if let Some(ref variant) = self.variant { + FlagValue::String(variant.clone()) + } else { + FlagValue::Boolean(true) + } + } else { + FlagValue::Boolean(false) + } + } + + /// Get the payload from metadata if present + pub fn payload(&self) -> Option { + self.metadata.as_ref().and_then(|m| m.payload.clone()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlagReason { + pub code: String, + #[serde(default)] + pub condition_index: Option, + #[serde(default)] + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlagMetadata { + pub id: u64, + pub version: u32, + pub description: Option, + pub payload: Option, +} + +const LONG_SCALE: f64 = 0xFFFFFFFFFFFFFFFu64 as f64; // Must be exactly 15 F's for hash compatibility + +pub fn hash_key(key: &str, distinct_id: &str, salt: &str) -> f64 { + let hash_key = format!("{key}.{distinct_id}{salt}"); + let mut hasher = Sha1::new(); + hasher.update(hash_key.as_bytes()); + let result = hasher.finalize(); + let hex_str = format!("{result:x}"); + let hash_val = u64::from_str_radix(&hex_str[..15], 16).unwrap_or(0); + hash_val as f64 / LONG_SCALE +} + +pub fn get_matching_variant(flag: &FeatureFlag, distinct_id: &str) -> Option { + let hash_value = hash_key(&flag.key, distinct_id, "variant"); + let variants = flag.filters.multivariate.as_ref()?.variants.as_slice(); + + let mut value_min = 0.0; + for variant in variants { + let value_max = value_min + variant.rollout_percentage / 100.0; + if hash_value >= value_min && hash_value < value_max { + return Some(variant.key.clone()); + } + value_min = value_max; + } + None +} + +pub(crate) fn match_feature_flag( + flag: &FeatureFlag, + distinct_id: &str, + properties: &HashMap, +) -> Result { + if !flag.active { + return Ok(FlagValue::Boolean(false)); + } + + let conditions = &flag.filters.groups; + + // Sort conditions to evaluate variant overrides first + let mut sorted_conditions = conditions.clone(); + sorted_conditions.sort_by_key(|c| if c.variant.is_some() { 0 } else { 1 }); + + let mut is_inconclusive = false; + + for condition in sorted_conditions { + match is_condition_match(flag, distinct_id, &condition, properties) { + Ok(true) => { + if let Some(variant_override) = &condition.variant { + // Check if variant is valid + if let Some(ref multivariate) = flag.filters.multivariate { + let valid_variants: Vec = multivariate + .variants + .iter() + .map(|v| v.key.clone()) + .collect(); + + if valid_variants.contains(variant_override) { + return Ok(FlagValue::String(variant_override.clone())); + } + } + } + + // Try to get matching variant or return true + if let Some(variant) = get_matching_variant(flag, distinct_id) { + return Ok(FlagValue::String(variant)); + } + return Ok(FlagValue::Boolean(true)); + } + Ok(false) => continue, + Err(_) => { + is_inconclusive = true; + } + } + } + + if is_inconclusive { + return Err(InconclusiveMatchError::new( + "Can't determine if feature flag is enabled or not with given properties", + )); + } + + Ok(FlagValue::Boolean(false)) +} + +fn is_condition_match( + flag: &FeatureFlag, + distinct_id: &str, + condition: &FeatureFlagCondition, + properties: &HashMap, +) -> Result { + // Check properties first + for prop in &condition.properties { + if !match_property(prop, properties)? { + return Ok(false); + } + } + + // If all properties match (or no properties), check rollout percentage + if let Some(rollout_percentage) = condition.rollout_percentage { + let hash_value = hash_key(&flag.key, distinct_id, ""); + if hash_value > (rollout_percentage / 100.0) { + return Ok(false); + } + } + + Ok(true) +} + +fn match_property( + property: &Property, + properties: &HashMap, +) -> Result { + let value = match properties.get(&property.key) { + Some(v) => v, + None => { + // Handle is_not_set operator + if property.operator == "is_not_set" { + return Ok(true); + } + // Handle is_set operator + if property.operator == "is_set" { + return Ok(false); + } + // For other operators, missing property is inconclusive + return Err(InconclusiveMatchError::new(&format!( + "Property '{}' not found in provided properties", + property.key + ))); + } + }; + + Ok(match property.operator.as_str() { + "exact" => { + if property.value.is_array() { + if let Some(arr) = property.value.as_array() { + for val in arr { + if compare_values(val, value) { + return Ok(true); + } + } + return Ok(false); + } + } + compare_values(&property.value, value) + } + "is_not" => { + if property.value.is_array() { + if let Some(arr) = property.value.as_array() { + for val in arr { + if compare_values(val, value) { + return Ok(false); + } + } + return Ok(true); + } + } + !compare_values(&property.value, value) + } + "is_set" => true, // We already know the property exists + "is_not_set" => false, // We already know the property exists + "icontains" => { + let prop_str = value_to_string(value); + let search_str = value_to_string(&property.value); + prop_str.to_lowercase().contains(&search_str.to_lowercase()) + } + "not_icontains" => { + let prop_str = value_to_string(value); + let search_str = value_to_string(&property.value); + !prop_str.to_lowercase().contains(&search_str.to_lowercase()) + } + "regex" => { + let prop_str = value_to_string(value); + let regex_str = value_to_string(&property.value); + match regex::Regex::new(®ex_str) { + Ok(re) => re.is_match(&prop_str), + Err(_) => false, + } + } + "not_regex" => { + let prop_str = value_to_string(value); + let regex_str = value_to_string(&property.value); + match regex::Regex::new(®ex_str) { + Ok(re) => !re.is_match(&prop_str), + Err(_) => true, + } + } + "gt" | "gte" | "lt" | "lte" => compare_numeric(&property.operator, &property.value, value), + _ => false, + }) +} + +fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> bool { + // Case-insensitive string comparison + if let (Some(a_str), Some(b_str)) = (a.as_str(), b.as_str()) { + return a_str.eq_ignore_ascii_case(b_str); + } + + // Direct comparison for other types + a == b +} + +fn value_to_string(value: &serde_json::Value) -> String { + match value { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Number(n) => n.to_string(), + serde_json::Value::Bool(b) => b.to_string(), + _ => value.to_string(), + } +} + +fn compare_numeric( + operator: &str, + property_value: &serde_json::Value, + value: &serde_json::Value, +) -> bool { + let prop_num = match property_value { + serde_json::Value::Number(n) => n.as_f64(), + serde_json::Value::String(s) => s.parse::().ok(), + _ => None, + }; + + let val_num = match value { + serde_json::Value::Number(n) => n.as_f64(), + serde_json::Value::String(s) => s.parse::().ok(), + _ => None, + }; + + if let (Some(prop), Some(val)) = (prop_num, val_num) { + match operator { + "gt" => val > prop, + "gte" => val >= prop, + "lt" => val < prop, + "lte" => val <= prop, + _ => false, + } + } else { + // Fall back to string comparison + let prop_str = value_to_string(property_value); + let val_str = value_to_string(value); + match operator { + "gt" => val_str > prop_str, + "gte" => val_str >= prop_str, + "lt" => val_str < prop_str, + "lte" => val_str <= prop_str, + _ => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_hash_key() { + let hash = hash_key("test-flag", "user-123", ""); + assert!(hash >= 0.0 && hash <= 1.0); + + // Same inputs should produce same hash + let hash2 = hash_key("test-flag", "user-123", ""); + assert_eq!(hash, hash2); + + // Different inputs should produce different hash + let hash3 = hash_key("test-flag", "user-456", ""); + assert_ne!(hash, hash3); + } + + #[test] + fn test_simple_flag_match() { + let flag = FeatureFlag { + key: "test-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + assert_eq!(result, FlagValue::Boolean(true)); + } + + #[test] + fn test_property_matching() { + let prop = Property { + key: "country".to_string(), + value: json!("US"), + operator: "exact".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("country".to_string(), json!("US")); + + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("country".to_string(), json!("UK")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_multivariate_variants() { + let flag = FeatureFlag { + key: "test-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: Some(MultivariateFilter { + variants: vec![ + MultivariateVariant { + key: "control".to_string(), + rollout_percentage: 50.0, + }, + MultivariateVariant { + key: "test".to_string(), + rollout_percentage: 50.0, + }, + ], + }), + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + + match result { + FlagValue::String(variant) => { + assert!(variant == "control" || variant == "test"); + } + _ => panic!("Expected string variant"), + } + } + + #[test] + fn test_inactive_flag() { + let flag = FeatureFlag { + key: "inactive-flag".to_string(), + active: false, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + assert_eq!(result, FlagValue::Boolean(false)); + } + + #[test] + fn test_rollout_percentage() { + let flag = FeatureFlag { + key: "rollout-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(30.0), // 30% rollout + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + + // Test with multiple users to ensure distribution + let mut enabled_count = 0; + for i in 0..1000 { + let result = match_feature_flag(&flag, &format!("user-{}", i), &properties).unwrap(); + if result == FlagValue::Boolean(true) { + enabled_count += 1; + } + } + + // Should be roughly 30% enabled (allow for some variance) + assert!(enabled_count > 250 && enabled_count < 350); + } + + #[test] + fn test_regex_operator() { + let prop = Property { + key: "email".to_string(), + value: json!(".*@company\\.com$"), + operator: "regex".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("email".to_string(), json!("user@company.com")); + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("email".to_string(), json!("user@example.com")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_icontains_operator() { + let prop = Property { + key: "name".to_string(), + value: json!("ADMIN"), + operator: "icontains".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("name".to_string(), json!("admin_user")); + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("name".to_string(), json!("regular_user")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_numeric_operators() { + // Greater than + let prop_gt = Property { + key: "age".to_string(), + value: json!(18), + operator: "gt".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("age".to_string(), json!(25)); + assert!(match_property(&prop_gt, &properties).unwrap()); + + properties.insert("age".to_string(), json!(15)); + assert!(!match_property(&prop_gt, &properties).unwrap()); + + // Less than or equal + let prop_lte = Property { + key: "score".to_string(), + value: json!(100), + operator: "lte".to_string(), + property_type: None, + }; + + properties.insert("score".to_string(), json!(100)); + assert!(match_property(&prop_lte, &properties).unwrap()); + + properties.insert("score".to_string(), json!(101)); + assert!(!match_property(&prop_lte, &properties).unwrap()); + } + + #[test] + fn test_is_set_operator() { + let prop = Property { + key: "email".to_string(), + value: json!(true), + operator: "is_set".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + properties.insert("email".to_string(), json!("test@example.com")); + assert!(match_property(&prop, &properties).unwrap()); + + properties.remove("email"); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_is_not_set_operator() { + let prop = Property { + key: "phone".to_string(), + value: json!(true), + operator: "is_not_set".to_string(), + property_type: None, + }; + + let mut properties = HashMap::new(); + assert!(match_property(&prop, &properties).unwrap()); + + properties.insert("phone".to_string(), json!("+1234567890")); + assert!(!match_property(&prop, &properties).unwrap()); + } + + #[test] + fn test_empty_groups() { + let flag = FeatureFlag { + key: "empty-groups".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + let properties = HashMap::new(); + let result = match_feature_flag(&flag, "user-123", &properties).unwrap(); + assert_eq!(result, FlagValue::Boolean(false)); + } + + #[test] + fn test_hash_scale_constant() { + // Verify the constant is exactly 15 F's (not 16) + assert_eq!(LONG_SCALE, 0xFFFFFFFFFFFFFFFu64 as f64); + assert_ne!(LONG_SCALE, 0xFFFFFFFFFFFFFFFFu64 as f64); + } +} diff --git a/src/global.rs b/src/global.rs index 3ffa9b9..97a002b 100644 --- a/src/global.rs +++ b/src/global.rs @@ -1,5 +1,6 @@ use std::sync::OnceLock; +use crate::error::InitializationError; use crate::{client, Client, ClientOptions, Error, Event}; static GLOBAL_CLIENT: OnceLock = OnceLock::new(); @@ -19,7 +20,7 @@ pub async fn init_global_client>(options: C) -> Result<() let client = client(options).await; GLOBAL_CLIENT .set(client) - .map_err(|_| Error::AlreadyInitialized) + .map_err(|_| InitializationError::AlreadyInitialized.into()) } /// [`init_global_client`] will initialize a globally available client singleton. This singleton @@ -36,7 +37,7 @@ pub fn init_global_client>(options: C) -> Result<(), Erro let client = client(options); GLOBAL_CLIENT .set(client) - .map_err(|_| Error::AlreadyInitialized) + .map_err(|_| InitializationError::AlreadyInitialized.into()) } /// [`disable`] prevents the global client from being initialized. @@ -55,13 +56,17 @@ pub fn is_disabled() -> bool { /// Capture the provided event, sending it to PostHog using the global client. #[cfg(feature = "async-client")] pub async fn capture(event: Event) -> Result<(), Error> { - let client = GLOBAL_CLIENT.get().ok_or(Error::NotInitialized)?; + let client = GLOBAL_CLIENT + .get() + .ok_or(InitializationError::NotInitialized)?; client.capture(event).await } /// Capture the provided event, sending it to PostHog using the global client. #[cfg(not(feature = "async-client"))] pub fn capture(event: Event) -> Result<(), Error> { - let client = GLOBAL_CLIENT.get().ok_or(Error::NotInitialized)?; + let client = GLOBAL_CLIENT + .get() + .ok_or(InitializationError::NotInitialized)?; client.capture(event) } diff --git a/src/lib.rs b/src/lib.rs index 6fe46d5..ac20eea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,10 @@ mod client; +mod endpoints; mod error; mod event; +mod feature_flags; mod global; - -const API_ENDPOINT: &str = "https://us.i.posthog.com/i/v0/e/"; +mod local_evaluation; // Public interface - any change to this is breaking! // Client @@ -11,14 +12,24 @@ pub use client::client; pub use client::Client; pub use client::ClientOptions; pub use client::ClientOptionsBuilder; -pub use client::ClientOptionsBuilderError; + +// Endpoints +pub use endpoints::{DEFAULT_HOST, EU_INGESTION_ENDPOINT, US_INGESTION_ENDPOINT}; // Error pub use error::Error; +pub use error::InitializationError; +pub use error::TransportError; +pub use error::ValidationError; +// for backward compatibility +pub use error::Error as ClientOptionsBuilderError; // Event pub use event::Event; +// Feature Flags +pub use feature_flags::{FeatureFlag, FlagDetail, FlagMetadata, FlagReason, FlagValue}; + // We expose a global capture function as a convenience, that uses a global client pub use global::capture; pub use global::disable as disable_global; diff --git a/src/local_evaluation.rs b/src/local_evaluation.rs new file mode 100644 index 0000000..035f9b4 --- /dev/null +++ b/src/local_evaluation.rs @@ -0,0 +1,540 @@ +use crate::endpoints::Endpoint; +use crate::error::{TransportError, ValidationError}; +use crate::feature_flags::{match_feature_flag, FeatureFlag, FlagValue, InconclusiveMatchError}; +use crate::Error; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; +#[cfg(not(feature = "async-client"))] +use std::sync::Mutex; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct LocalEvaluationResponse { + pub flags: Vec, + #[serde(default)] + pub group_type_mapping: HashMap, + #[serde(default)] + pub cohorts: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Cohort { + pub id: String, + pub name: String, + pub properties: serde_json::Value, +} + +/// Manages locally cached feature flags for evaluation +#[derive(Clone)] +pub(crate) struct FlagCache { + flags: Arc>>, + group_type_mapping: Arc>>, + cohorts: Arc>>, +} + +impl Default for FlagCache { + fn default() -> Self { + Self::new() + } +} + +impl FlagCache { + pub fn new() -> Self { + Self { + flags: Arc::new(RwLock::new(HashMap::new())), + group_type_mapping: Arc::new(RwLock::new(HashMap::new())), + cohorts: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub fn update(&self, response: LocalEvaluationResponse) { + let mut flags = self.flags.write().unwrap(); + flags.clear(); + for flag in response.flags { + flags.insert(flag.key.clone(), flag); + } + + let mut mapping = self.group_type_mapping.write().unwrap(); + *mapping = response.group_type_mapping; + + let mut cohorts = self.cohorts.write().unwrap(); + *cohorts = response.cohorts; + } + + pub fn get_flag(&self, key: &str) -> Option { + self.flags.read().unwrap().get(key).cloned() + } + + #[cfg(test)] // only for tests uses + fn clear(&self) { + self.flags.write().unwrap().clear(); + self.group_type_mapping.write().unwrap().clear(); + self.cohorts.write().unwrap().clear(); + } +} + +/// Configuration for local evaluation +#[derive(Clone)] +pub(crate) struct LocalEvaluationConfig { + /// Personal API key for authentication (sensitive - transmitted via Authorization header only) + pub personal_api_key: String, + /// Project API key for project identification (public - safe to include in URLs) + /// Note: PostHog project API keys (phc_*) are designed to be public and used in client-side code. + /// See: + pub project_api_key: String, + pub api_host: String, + pub poll_interval: Duration, + pub request_timeout: Duration, +} + +/// Manages polling for feature flag definitions +#[cfg(not(feature = "async-client"))] +pub(crate) struct FlagPoller { + config: LocalEvaluationConfig, + cache: FlagCache, + client: reqwest::blocking::Client, + stop_signal: Arc, + thread_handle: Arc>>>, +} + +#[cfg(not(feature = "async-client"))] +impl FlagPoller { + pub(crate) fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self { + let client = reqwest::blocking::Client::builder() + .timeout(config.request_timeout) + .build() + .unwrap(); + + Self { + config, + cache, + client, + stop_signal: Arc::new(AtomicBool::new(false)), + thread_handle: Arc::new(Mutex::new(None)), + } + } + + /// Start the polling thread + pub(crate) fn start(&self) { + // Initial load (silently ignore errors) + let _ = self.load_flags(); + + let config = self.config.clone(); + let cache = self.cache.clone(); + let stop_signal = self.stop_signal.clone(); + + let handle = std::thread::spawn(move || { + let client = reqwest::blocking::Client::builder() + .timeout(config.request_timeout) + .build() + .unwrap(); + + loop { + std::thread::sleep(config.poll_interval); + + if stop_signal.load(Ordering::Relaxed) { + break; + } + + // Note: project_api_key (phc_*) is public and safe in URLs - see `LocalEvaluationConfig ` struct docs + let url = format!( + "{}{}?token={}&send_cohorts", + config.api_host.trim_end_matches('/'), + Endpoint::LocalEvaluation.path(), + config.project_api_key + ); + + match client + .get(&url) // CODEQL_IGNORE + .header( + "Authorization", + format!("Bearer {}", config.personal_api_key), + ) + .send() + { + Ok(response) => { + if response.status().is_success() { + if let Ok(data) = response.json::() { + cache.update(data); + } + } + } + Err(_e) => {} + } + } + }); + + *self.thread_handle.lock().unwrap() = Some(handle); + } + + /// Load flags synchronously + fn load_flags(&self) -> Result<(), Error> { + // Note: project_api_key (phc_*) is public and safe in URLs - see `LocalEvaluationConfig ` struct docs + let url = format!( + "{}{}?token={}&send_cohorts", + self.config.api_host.trim_end_matches('/'), + Endpoint::LocalEvaluation.path(), + self.config.project_api_key + ); + + let response = self + .client + .get(&url) // CODEQL_IGNORE + .header( + "Authorization", + format!("Bearer {}", self.config.personal_api_key), + ) + .send() + .map_err(TransportError::from)?; + + if !response.status().is_success() { + return Err(TransportError::HttpError( + response.status().as_u16(), + format!("HTTP {}", response.status()), + ) + .into()); + } + + let data = response + .json::() + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; + + self.cache.update(data); + Ok(()) + } + + /// Stop the polling thread + fn stop(&self) { + self.stop_signal.store(true, Ordering::Relaxed); + if let Some(handle) = self.thread_handle.lock().unwrap().take() { + handle.join().ok(); + } + } +} + +#[cfg(not(feature = "async-client"))] +impl Drop for FlagPoller { + fn drop(&mut self) { + self.stop(); + } +} + +/// Async version of the flag poller +#[cfg(feature = "async-client")] +pub(crate) struct AsyncFlagPoller { + config: LocalEvaluationConfig, + cache: FlagCache, + client: reqwest::Client, + stop_signal: Arc, + task_handle: Arc>>>, + is_running: Arc, +} + +#[cfg(feature = "async-client")] +impl AsyncFlagPoller { + pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self { + let client = reqwest::Client::builder() + .timeout(config.request_timeout) + .build() + .unwrap(); + + Self { + config, + cache, + client, + stop_signal: Arc::new(AtomicBool::new(false)), + task_handle: Arc::new(tokio::sync::Mutex::new(None)), + is_running: Arc::new(AtomicBool::new(false)), + } + } + + /// Start the polling task + pub async fn start(&self) { + // Check if already running + if self.is_running.swap(true, Ordering::Relaxed) { + return; // Already running + } + + // Initial load (silently ignore errors) + let _ = self.load_flags().await; + + let config = self.config.clone(); + let cache = self.cache.clone(); + let stop_signal = self.stop_signal.clone(); + let is_running = self.is_running.clone(); + let client = self.client.clone(); + + let task = tokio::spawn(async move { + let mut interval = tokio::time::interval(config.poll_interval); + interval.tick().await; // Skip the first immediate tick + + loop { + tokio::select! { + _ = interval.tick() => { + if stop_signal.load(Ordering::Relaxed) { + break; + } + + // Note: project_api_key (phc_*) is public and safe in URLs - see `LocalEvaluationConfig ` struct docs + let url = format!( + "{}{}?token={}&send_cohorts", + config.api_host.trim_end_matches('/'), + Endpoint::LocalEvaluation.path(), + config.project_api_key + ); + + match client + .get(&url) // CODEQL_IGNORE + .header("Authorization", format!("Bearer {}", config.personal_api_key)) + .send() + .await + { + Ok(response) => { + if response.status().is_success() { + if let Ok(data) = response.json::().await { + cache.update(data); + } + } + } + Err(_e) => {}, + } + } + } + } + + // Clear running flag when task exits + is_running.store(false, Ordering::Relaxed); + }); + + *self.task_handle.lock().await = Some(task); + } + + /// Load flags asynchronously + pub async fn load_flags(&self) -> Result<(), Error> { + // Note: project_api_key (phc_*) is public and safe in URLs - see `LocalEvaluationConfig ` struct docs + let url = format!( + "{}{}?token={}&send_cohorts", + self.config.api_host.trim_end_matches('/'), + Endpoint::LocalEvaluation.path(), + self.config.project_api_key + ); + + let response = self + .client + .get(&url) // CODEQL_IGNORE + .header( + "Authorization", + format!("Bearer {}", self.config.personal_api_key), + ) + .send() + .await + .map_err(TransportError::from)?; + + if !response.status().is_success() { + return Err(TransportError::HttpError( + response.status().as_u16(), + format!("HTTP {}", response.status()), + ) + .into()); + } + + let data = response + .json::() + .await + .map_err(|e| ValidationError::SerializationFailed(e.to_string()))?; + + self.cache.update(data); + Ok(()) + } +} + +#[cfg(feature = "async-client")] +impl Drop for AsyncFlagPoller { + fn drop(&mut self) { + // Set stop signal + self.stop_signal.store(true, Ordering::Relaxed); + + // Abort the task if still running + if let Ok(mut guard) = self.task_handle.try_lock() { + if let Some(handle) = guard.take() { + handle.abort(); + } + } + } +} + +/// Evaluator for locally cached flags +pub(crate) struct LocalEvaluator { + cache: FlagCache, +} + +impl LocalEvaluator { + pub fn new(cache: FlagCache) -> Self { + Self { cache } + } + + /// Evaluate a feature flag locally + pub fn evaluate_flag( + &self, + key: &str, + distinct_id: &str, + person_properties: &HashMap, + ) -> Result, InconclusiveMatchError> { + match self.cache.get_flag(key) { + Some(flag) => match_feature_flag(&flag, distinct_id, person_properties).map(Some), + None => Ok(None), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::feature_flags::{FeatureFlagCondition, FeatureFlagFilters, Property}; + use serde_json::json; + + #[test] + fn test_local_evaluation_basic() { + // Create a cache and evaluator + let cache = FlagCache::new(); + let evaluator = LocalEvaluator::new(cache.clone()); + + // Create a simple flag + let flag = FeatureFlag { + key: "test-flag".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + // Update cache with the flag + let response = LocalEvaluationResponse { + flags: vec![flag], + group_type_mapping: HashMap::new(), + cohorts: HashMap::new(), + }; + cache.update(response); + + // Test evaluation + let properties = HashMap::new(); + let result = evaluator.evaluate_flag("test-flag", "user-123", &properties); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + } + + #[test] + fn test_local_evaluation_with_properties() { + let cache = FlagCache::new(); + let evaluator = LocalEvaluator::new(cache.clone()); + + // Create a flag with property conditions + let flag = FeatureFlag { + key: "premium-feature".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![FeatureFlagCondition { + properties: vec![Property { + key: "plan".to_string(), + value: json!("premium"), + operator: "exact".to_string(), + property_type: None, + }], + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + payloads: HashMap::new(), + }, + }; + + // Update cache + let response = LocalEvaluationResponse { + flags: vec![flag], + group_type_mapping: HashMap::new(), + cohorts: HashMap::new(), + }; + cache.update(response); + + // Test with matching properties + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("premium")); + + let result = evaluator.evaluate_flag("premium-feature", "user-123", &properties); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + // Test with non-matching properties + let mut properties = HashMap::new(); + properties.insert("plan".to_string(), json!("free")); + + let result = evaluator.evaluate_flag("premium-feature", "user-456", &properties); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(false))); + } + + #[test] + fn test_local_evaluation_missing_flag() { + let cache = FlagCache::new(); + let evaluator = LocalEvaluator::new(cache); + + let properties = HashMap::new(); + let result = evaluator.evaluate_flag("non-existent", "user-123", &properties); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), None); + } + + #[test] + fn test_cache_operations() { + let cache = FlagCache::new(); + + // Create multiple flags + let flags = vec![ + FeatureFlag { + key: "flag1".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![], + multivariate: None, + payloads: HashMap::new(), + }, + }, + FeatureFlag { + key: "flag2".to_string(), + active: true, + filters: FeatureFlagFilters { + groups: vec![], + multivariate: None, + payloads: HashMap::new(), + }, + }, + ]; + + let response = LocalEvaluationResponse { + flags: flags.clone(), + group_type_mapping: HashMap::new(), + cohorts: HashMap::new(), + }; + + cache.update(response); + + // Test get_flag + assert!(cache.get_flag("flag1").is_some()); + assert!(cache.get_flag("flag2").is_some()); + assert!(cache.get_flag("flag3").is_none()); + + // Test clear + cache.clear(); + assert!(cache.get_flag("flag1").is_none()); + } +} diff --git a/tests/api_compatibility.rs b/tests/api_compatibility.rs new file mode 100644 index 0000000..ea4b76e --- /dev/null +++ b/tests/api_compatibility.rs @@ -0,0 +1,107 @@ +/// API stability tests +/// These tests will fail to compile if we make breaking changes to the public API. +#[cfg(test)] +mod api_compatibility_tests { + use posthog_rs::*; + + #[test] + fn test_event_constructors_exist() { + let _event1 = Event::new("test", "user123"); + let _event2 = Event::new_anon("test"); + } + + #[test] + fn test_event_methods_exist() { + let mut event = Event::new("test", "user123"); + let _ = event.insert_prop("key", "value"); + event.add_group("group", "id"); + let _ = event.set_timestamp(chrono::Utc::now()); + } + + #[test] + fn test_client_options_builder() { + let _options = ClientOptionsBuilder::default() + .api_key("test".to_string()) + .build(); + } + + #[test] + fn test_client_options_from_str() { + let _options: ClientOptions = "test_api_key".into(); + } + + #[cfg(not(feature = "async-client"))] + #[test] + fn test_blocking_client_methods_exist() { + fn _check_blocking_client(client: &Client) { + let event = Event::new("test", "user"); + let _: Result<(), Error> = client.capture(event); + let _: Result<(), Error> = client.capture_batch(vec![]); + } + } + + #[cfg(feature = "async-client")] + #[test] + fn test_async_client_methods_exist() { + async fn _check_async_client(client: &Client) { + let event = Event::new("test", "user"); + let _: Result<(), Error> = client.capture(event).await; + let _: Result<(), Error> = client.capture_batch(vec![]).await; + } + } + + #[test] + fn test_global_functions_exist() { + #[cfg(feature = "async-client")] + async fn _check_async_global() { + let _init: fn(ClientOptions) -> _ = init_global; + let event = Event::new("test", "user"); + let _: Result<(), Error> = capture(event).await; + let _disable: fn() = disable_global; + let _is_disabled: fn() -> bool = global_is_disabled; + } + + #[cfg(not(feature = "async-client"))] + fn _check_blocking_global() { + let _init: fn(ClientOptions) -> Result<(), Error> = init_global; + let event = Event::new("test", "user"); + let _: Result<(), Error> = capture(event); + let _disable: fn() = disable_global; + let _is_disabled: fn() -> bool = global_is_disabled; + } + } + + #[test] + fn test_error_types_exist() { + // Ensure error types can be constructed and matched + fn _handle_errors() { + let _transport = TransportError::Timeout(std::time::Duration::from_secs(30)); + let _validation = ValidationError::InvalidTimestamp("test".to_string()); + let _init = InitializationError::MissingApiKey; + + let _err1: Error = Error::Transport(TransportError::NetworkUnreachable); + let _err2: Error = Error::Validation(ValidationError::InvalidDistinctId("".into())); + let _err3: Error = Error::Initialization(InitializationError::NotInitialized); + } + } + + #[test] + fn test_error_methods_exist() { + let err = Error::Transport(TransportError::Timeout(std::time::Duration::from_secs(30))); + let _: bool = err.is_retryable(); + let _: bool = err.is_client_error(); + } + + #[test] + fn test_non_exhaustive_pattern_matching() { + let err = Error::Transport(TransportError::Timeout(std::time::Duration::from_secs(1))); + + // Must include catch-all due to #[non_exhaustive] + match err { + Error::Transport(_) => {} + Error::Validation(_) => {} + Error::Initialization(_) => {} + _ => {} + } + } +} diff --git a/tests/backward_compatibility.rs b/tests/backward_compatibility.rs new file mode 100644 index 0000000..73da827 --- /dev/null +++ b/tests/backward_compatibility.rs @@ -0,0 +1,269 @@ +/// Backward compatibility tests for the PostHog Rust SDK. +/// +/// This test suite verifies: +/// - Error retry logic (which errors can be retried) +/// - Error classification (client vs infrastructure errors) +/// - Backward compatibility (deprecated variants still work) +#[cfg(test)] +mod backward_compatibility_tests { + use posthog_rs::*; + use std::time::Duration; + + // ===== Retry Logic Tests ===== + + #[test] + fn test_http_error_retry_logic() { + // Test the actual business logic: which HTTP status codes are retryable + + // 5xx errors should be retryable + for status in 500..600 { + let err = Error::Transport(TransportError::HttpError(status, "".to_string())); + assert!(err.is_retryable(), "HTTP {} should be retryable", status); + } + + // 4xx errors should NOT be retryable, except 429 + for status in 400..500 { + let err = Error::Transport(TransportError::HttpError(status, "".to_string())); + if status == 429 { + assert!( + err.is_retryable(), + "HTTP 429 should be retryable (rate limit)" + ); + } else { + assert!( + !err.is_retryable(), + "HTTP {} should not be retryable", + status + ); + } + } + } + + #[test] + fn test_http_error_client_classification() { + // Test the actual business logic: which HTTP errors are client errors + + // 4xx are client errors (user's fault) + for status in 400..500 { + let err = Error::Transport(TransportError::HttpError(status, "".to_string())); + assert!( + err.is_client_error(), + "HTTP {} should be a client error", + status + ); + } + + // 5xx are NOT client errors (server's fault) + for status in 500..600 { + let err = Error::Transport(TransportError::HttpError(status, "".to_string())); + assert!( + !err.is_client_error(), + "HTTP {} should not be a client error", + status + ); + } + } + + // ===== Real-World Use Case Tests ===== + + #[test] + fn test_retry_strategy_with_backoff() { + // Simulates real-world retry logic with attempt limits + fn should_retry(err: &Error, attempt: u32, max_attempts: u32) -> bool { + err.is_retryable() && attempt < max_attempts + } + + // Retryable errors + let retryable = vec![ + Error::Transport(TransportError::Timeout(Duration::from_secs(30))), + Error::Transport(TransportError::NetworkUnreachable), + Error::Transport(TransportError::HttpError(503, "unavailable".to_string())), + Error::Transport(TransportError::HttpError(429, "rate limit".to_string())), + ]; + + for err in retryable { + assert!(should_retry(&err, 1, 3), "Should retry: {:?}", err); + } + + // Non-retryable errors + let non_retryable = vec![ + Error::Transport(TransportError::HttpError(401, "unauthorized".to_string())), + Error::Transport(TransportError::DnsResolution("bad.host".to_string())), + Error::Validation(ValidationError::InvalidTimestamp("bad".to_string())), + Error::Initialization(InitializationError::MissingApiKey), + ]; + + for err in non_retryable { + assert!(!should_retry(&err, 1, 3), "Should not retry: {:?}", err); + } + } + + #[test] + fn test_error_severity_for_logging() { + // Simulates real-world logging strategy based on error type + fn log_level(err: &Error) -> &str { + match (err.is_client_error(), err.is_retryable()) { + (true, _) => "ERROR", // Client error - user needs to fix + (false, true) => "WARN", // Retryable - transient issue + (false, false) => "ERROR", // Not retryable - permanent issue + } + } + + // Client errors get ERROR level + assert_eq!( + log_level(&Error::Validation(ValidationError::InvalidTimestamp( + "x".into() + ))), + "ERROR" + ); + assert_eq!( + log_level(&Error::Transport(TransportError::HttpError( + 400, + "bad".into() + ))), + "ERROR" + ); + + // Retryable errors get WARN level + assert_eq!( + log_level(&Error::Transport(TransportError::Timeout( + Duration::from_secs(1) + ))), + "WARN" + ); + assert_eq!( + log_level(&Error::Transport(TransportError::HttpError( + 503, + "unavailable".into() + ))), + "WARN" + ); + + // Permanent infrastructure errors get ERROR level + assert_eq!( + log_level(&Error::Transport(TransportError::DnsResolution("x".into()))), + "ERROR" + ); + } + + // ===== Backward Compatibility: Deprecated Variants ===== + + #[allow(deprecated)] + #[test] + fn test_deprecated_errors_with_new_methods() { + // CRITICAL: Ensure deprecated errors work with new is_retryable() and is_client_error() + let deprecated_errors = vec![ + Error::Connection("timeout".to_string()), + Error::Serialization("bad json".to_string()), + Error::AlreadyInitialized, + Error::NotInitialized, + Error::InvalidTimestamp("future".to_string()), + ]; + + for err in deprecated_errors { + // Deprecated errors conservatively return false for both methods + assert!( + !err.is_retryable(), + "Deprecated error should not be retryable by default" + ); + assert!( + !err.is_client_error(), + "Deprecated error should not be client error by default" + ); + } + } + + #[allow(deprecated)] + #[test] + fn test_old_and_new_errors_coexist() { + // CRITICAL: Ensure old and new error types can be handled together + fn categorize(err: Error) -> &'static str { + match err { + Error::Transport(_) => "new_transport", + Error::Validation(_) => "new_validation", + Error::Initialization(_) => "new_initialization", + Error::Connection(_) => "deprecated_connection", + Error::Serialization(_) => "deprecated_serialization", + Error::AlreadyInitialized => "deprecated_already_init", + Error::NotInitialized => "deprecated_not_init", + Error::InvalidTimestamp(_) => "deprecated_timestamp", + _ => "unknown", + } + } + + // New errors work + assert_eq!( + categorize(Error::Transport(TransportError::Timeout( + Duration::from_secs(1) + ))), + "new_transport" + ); + + // Old errors still work + assert_eq!( + categorize(Error::Connection("err".to_string())), + "deprecated_connection" + ); + } + + #[allow(deprecated)] + #[test] + fn test_deprecated_error_construction_and_matching() { + // Verify basic construction and pattern matching still works + + // String variants + let conn_err = Error::Connection("network failure".to_string()); + assert!(matches!(conn_err, Error::Connection(_))); + assert!(conn_err.to_string().contains("Connection error")); + + let serial_err = Error::Serialization("invalid json".to_string()); + assert!(matches!(serial_err, Error::Serialization(_))); + + let ts_err = Error::InvalidTimestamp("future time".to_string()); + assert!(matches!(ts_err, Error::InvalidTimestamp(_))); + + // Unit variants + let already_init = Error::AlreadyInitialized; + assert!(matches!(already_init, Error::AlreadyInitialized)); + + let not_init = Error::NotInitialized; + assert!(matches!(not_init, Error::NotInitialized)); + } + + #[test] + fn test_migration_path_documented() { + // Documents the migration path from old to new error types + // This test doesn't assert anything - it just shows the mapping + + // Old: Error::Connection → New: Error::Transport(TransportError::*) + let _timeout = Error::Transport(TransportError::Timeout(Duration::from_secs(30))); + let _http = Error::Transport(TransportError::HttpError(500, "error".to_string())); + + // Old: Error::Serialization → New: Error::Validation(ValidationError::SerializationFailed) + let _serial = Error::Validation(ValidationError::SerializationFailed("err".to_string())); + + // Old: Error::AlreadyInitialized → New: Error::Initialization(InitializationError::AlreadyInitialized) + let _already_init = Error::Initialization(InitializationError::AlreadyInitialized); + + // Old: Error::NotInitialized → New: Error::Initialization(InitializationError::NotInitialized) + let _not_init = Error::Initialization(InitializationError::NotInitialized); + + // Old: Error::InvalidTimestamp → New: Error::Validation(ValidationError::InvalidTimestamp) + let _timestamp = Error::Validation(ValidationError::InvalidTimestamp("err".to_string())); + } + + // ===== Non-Exhaustive Enum Behavior ===== + + #[test] + fn test_non_exhaustive_requires_catch_all() { + // Verifies #[non_exhaustive] works correctly - users must include catch-all + let err = Error::Transport(TransportError::Timeout(Duration::from_secs(1))); + + match err { + Error::Transport(_) => {} + Error::Validation(_) => {} + Error::Initialization(_) => {} + _ => {} // This is required due to #[non_exhaustive] + } + } +} diff --git a/tests/test.rs b/tests/test.rs index 4d27cca..a14511b 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,31 @@ -#[cfg(feature = "e2e-test")] +#[cfg(all(feature = "e2e-test", feature = "async-client"))] +#[tokio::test] +async fn get_client_async() { + use dotenv::dotenv; + dotenv().ok(); // Load the .env file + println!("Loaded .env for tests"); + + // see https://us.posthog.com/project/115809/ for the e2e project + use posthog_rs::Event; + use std::collections::HashMap; + + let api_key = std::env::var("POSTHOG_RS_E2E_TEST_API_KEY").unwrap(); + let client = posthog_rs::client(api_key.as_str()).await; + + let mut child_map = HashMap::new(); + child_map.insert("child_key1", "child_value1"); + + let mut event = Event::new("e2e test event", "1234"); + event.insert_prop("key1", "value1").unwrap(); + event.insert_prop("key2", vec!["a", "b"]).unwrap(); + event.insert_prop("key3", child_map).unwrap(); + + client.capture(event).await.unwrap(); +} + +#[cfg(all(feature = "e2e-test", not(feature = "async-client")))] #[test] -fn get_client() { +fn get_client_blocking() { use dotenv::dotenv; dotenv().ok(); // Load the .env file println!("Loaded .env for tests"); @@ -22,3 +47,138 @@ fn get_client() { client.capture(event).unwrap(); } + +// E2E Test for Feature Flag Events +#[cfg(all(feature = "e2e-test", feature = "async-client"))] +#[tokio::test] +async fn test_feature_flag_events_e2e_async() { + use dotenv::dotenv; + use std::collections::HashMap; + dotenv().ok(); + + let api_key = std::env::var("POSTHOG_RS_E2E_TEST_API_KEY").unwrap(); + let client = posthog_rs::client(api_key.as_str()).await; + + println!("Testing feature flag evaluation with real API..."); + + // Test 1: Evaluate a feature flag (this should create a $feature_flag_called event) + let mut properties = HashMap::new(); + properties.insert("email".to_string(), serde_json::json!("test@example.com")); + + let result = client + .get_feature_flag( + "test-flag", + "e2e-test-user-123", + None, + Some(properties.clone()), + None, + ) + .await; + + println!("Feature flag result: {:?}", result); + assert!(result.is_ok()); + + // Test 2: Test is_feature_enabled (should also create event) + let enabled_result = client + .is_feature_enabled( + "test-flag", + "e2e-test-user-456", + None, + Some(properties.clone()), + None, + ) + .await; + + println!("Feature enabled result: {:?}", enabled_result); + assert!(enabled_result.is_ok()); + + // Test 3: Check same flag again for same user - should NOT create new event (deduplication) + println!("\nTesting deduplication - calling same flag for user-123 again..."); + let duplicate_result = client + .get_feature_flag( + "test-flag", + "e2e-test-user-123", + None, + Some(properties.clone()), + None, + ) + .await; + + println!("Duplicate call result: {:?}", duplicate_result); + println!("Note: This should NOT create a new event (same user + flag + response)"); + + // Give time for events to be sent + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + println!("\n=== E2E Test Summary ==="); + println!("✅ Total flag evaluations: 3"); + println!("✅ Expected events in PostHog: 2"); + println!(" - Event 1: user-123 + test-flag + false"); + println!(" - Event 2: user-456 + test-flag + false"); + println!(" - Event 3: DEDUPLICATED (same as Event 1)"); + println!("\nCheck PostHog for exactly 2 new $feature_flag_called events."); +} + +#[cfg(all(feature = "e2e-test", not(feature = "async-client")))] +#[test] +fn test_feature_flag_events_e2e_blocking() { + use dotenv::dotenv; + use std::collections::HashMap; + dotenv().ok(); + + let api_key = std::env::var("POSTHOG_RS_E2E_TEST_API_KEY").unwrap(); + let client = posthog_rs::client(api_key.as_str()); + + println!("Testing feature flag evaluation with real API (blocking)..."); + + // Test 1: Evaluate a feature flag (this should create a $feature_flag_called event) + let mut properties = HashMap::new(); + properties.insert("email".to_string(), serde_json::json!("test@example.com")); + + let result = client.get_feature_flag( + "test-flag".to_string(), + "e2e-test-user-123".to_string(), + None, + Some(properties.clone()), + None, + ); + + println!("Feature flag result: {:?}", result); + assert!(result.is_ok()); + + // Test 2: Test is_feature_enabled (should also create event) + let enabled_result = client.is_feature_enabled( + "test-flag".to_string(), + "e2e-test-user-456".to_string(), + None, + Some(properties.clone()), + None, + ); + + println!("Feature enabled result: {:?}", enabled_result); + assert!(enabled_result.is_ok()); + + // Test 3: Check same flag again for same user - should NOT create new event (deduplication) + println!("\nTesting deduplication - calling same flag for user-123 again..."); + let duplicate_result = client.get_feature_flag( + "test-flag".to_string(), + "e2e-test-user-123".to_string(), + None, + Some(properties.clone()), + None, + ); + + println!("Duplicate call result: {:?}", duplicate_result); + println!("Note: This should NOT create a new event (same user + flag + response)"); + + // Give time for events to be sent + std::thread::sleep(std::time::Duration::from_secs(2)); + + println!("\n=== E2E Test Summary ==="); + println!("✅ Total flag evaluations: 3"); + println!("✅ Expected events in PostHog: 2"); + println!(" - Event 1: user-123 + test-flag + false"); + println!(" - Event 2: user-456 + test-flag + false"); + println!(" - Event 3: DEDUPLICATED (same as Event 1)"); + println!("\nCheck PostHog for exactly 2 new $feature_flag_called events."); +} diff --git a/tests/test_async.rs b/tests/test_async.rs new file mode 100644 index 0000000..592e937 --- /dev/null +++ b/tests/test_async.rs @@ -0,0 +1,710 @@ +#![cfg(feature = "async-client")] + +use httpmock::prelude::*; +use posthog_rs::FlagValue; +use serde_json::json; +use std::collections::HashMap; + +async fn create_test_client(base_url: String) -> posthog_rs::Client { + // Use the From implementation to ensure endpoint_manager is set up correctly + let options: posthog_rs::ClientOptions = (("test_api_key", base_url.as_str())).into(); + posthog_rs::client(options).await +} + +#[tokio::test] +async fn test_get_all_feature_flags() { + let server = MockServer::start(); + + let mock_response = json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + }, + "disabled-flag": { + "key": "disabled-flag", + "enabled": false, + "variant": null + }, + "variant-flag": { + "key": "variant-flag", + "enabled": true, + "variant": "control", + "metadata": { + "id": 1, + "version": 1, + "payload": "{\"color\": \"blue\", \"size\": \"large\"}" + } + } + } + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user" + })); + then.status(200) + .header("content-type", "application/json") + .json_body(mock_response); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_all_flags_and_payloads("test-user".to_string(), None, None, None) + .await; + + if let Err(e) = &result { + eprintln!("Error: {:?}", e); + eprintln!("Mock server URL: {}", server.base_url()); + } + assert!(result.is_ok()); + let (feature_flags, payloads) = result.unwrap(); + + assert_eq!( + feature_flags.get("test-flag"), + Some(&FlagValue::Boolean(true)) + ); + assert_eq!( + feature_flags.get("disabled-flag"), + Some(&FlagValue::Boolean(false)) + ); + assert_eq!( + feature_flags.get("variant-flag"), + Some(&FlagValue::String("control".to_string())) + ); + + assert!(payloads.contains_key("variant-flag")); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_is_feature_enabled() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "enabled-flag": { + "key": "enabled-flag", + "enabled": true, + "variant": null + }, + "disabled-flag": { + "key": "disabled-flag", + "enabled": false, + "variant": null + } + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let enabled_result = client + .is_feature_enabled( + "enabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + let disabled_result = client + .is_feature_enabled( + "disabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(disabled_result.is_ok()); + assert_eq!(disabled_result.unwrap(), false); + + flags_mock.assert_hits(2); +} + +#[tokio::test] +async fn test_get_feature_flag_with_properties() { + let server = MockServer::start(); + + let person_properties = json!({ + "country": "US", + "age": 25, + "plan": "premium" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user", + "person_properties": person_properties + })); + then.status(200).json_body(json!({ + "flags": { + "premium-feature": { + "key": "premium-feature", + "enabled": true, + "variant": null + } + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let mut props = HashMap::new(); + props.insert("country".to_string(), json!("US")); + props.insert("age".to_string(), json!(25)); + props.insert("plan".to_string(), json!("premium")); + + let result = client + .get_feature_flag( + "premium-feature".to_string(), + "test-user".to_string(), + None, + Some(props), + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_multivariate_flag() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "experiment": { + "key": "experiment", + "enabled": true, + "variant": "variant-b" + } + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(FlagValue::String("variant-b".to_string())) + ); + + let enabled_result = client + .is_feature_enabled( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + flags_mock.assert_hits(2); +} + +#[tokio::test] +async fn test_api_error_handling() { + let server = MockServer::start(); + + let error_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(500).body("Internal Server Error"); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flags("test-user".to_string(), None, None, None) + .await; + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("500")); + + error_mock.assert(); +} + +#[tokio::test] +async fn test_get_feature_flag_payload() { + let server = MockServer::start(); + + let payload_data = json!({ + "steps": ["welcome", "profile", "preferences"], + "theme": "dark" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "onboarding-flow": { + "key": "onboarding-flow", + "enabled": true, + "variant": "variant-a", + "metadata": { + "id": 1, + "version": 1, + "payload": payload_data + } + } + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_all_flags_and_payloads("test-user".to_string(), None, None, None) + .await; + + assert!(result.is_ok()); + let (_flags, payloads) = result.unwrap(); + let payload = payloads.get("onboarding-flow"); + assert!(payload.is_some()); + + let payload_value = payload.unwrap(); + assert_eq!(payload_value["theme"], "dark"); + assert!(payload_value["steps"].is_array()); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_nonexistent_flag() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": {} + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag( + "nonexistent-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), None); + + let enabled_result = client + .is_feature_enabled( + "nonexistent-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ) + .await; + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), false); + + flags_mock.assert_hits(2); +} + +#[tokio::test] +async fn test_empty_distinct_id() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "" + })); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag("test-flag".to_string(), "".to_string(), None, None, None) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_groups_parameter() { + let server = MockServer::start(); + + let groups_json = json!({ + "company": "acme-corp", + "team": "engineering" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user", + "groups": groups_json + })); + then.status(200).json_body(json!({ + "flags": { + "team-feature": { + "key": "team-feature", + "enabled": true, + "variant": null + } + } + })); + }); + + let client = create_test_client(server.base_url()).await; + + let mut groups = HashMap::new(); + groups.insert("company".to_string(), "acme-corp".to_string()); + groups.insert("team".to_string(), "engineering".to_string()); + + let result = client + .get_feature_flag( + "team-feature".to_string(), + "test-user".to_string(), + Some(groups), + None, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[tokio::test] +async fn test_malformed_response() { + let server = MockServer::start(); + + let malformed_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).body("not json"); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flags("test-user".to_string(), None, None, None) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("expected")); + + malformed_mock.assert(); +} + +// Feature Flag Event Tests + +#[tokio::test] +async fn test_feature_flag_event_captured() { + // Test that $feature_flag_called event is captured when calling get_feature_flag + let server = MockServer::start(); + + // Mock the flags endpoint + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + // Mock the capture endpoint to verify event is sent + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called" + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag("test-flag", "test-user", None, None, None) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + flags_mock.assert(); + capture_mock.assert(); +} + +#[tokio::test] +async fn test_feature_flag_event_deduplication() { + // Test that calling same flag for same user doesn't send duplicate events + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/"); + then.status(200); + }); + + let client = create_test_client(server.base_url()).await; + + // First call - should capture event + client + .get_feature_flag("test-flag", "test-user", None, None, None) + .await + .ok(); + + // Second call - should NOT capture event (deduplication) + client + .get_feature_flag("test-flag", "test-user", None, None, None) + .await + .ok(); + + flags_mock.assert_hits(2); + capture_mock.assert_hits(1); // Only 1 event captured, not 2 +} + +#[tokio::test] +async fn test_feature_flag_event_different_user() { + // Test that calling same flag for different user captures new event + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called" + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()).await; + + // Call for user1 - should capture event + client + .get_feature_flag("test-flag", "user1", None, None, None) + .await + .ok(); + + // Call for user2 - should capture event (different user) + client + .get_feature_flag("test-flag", "user2", None, None, None) + .await + .ok(); + + flags_mock.assert_hits(2); + capture_mock.assert_hits(2); // 2 events captured for different users +} + +#[tokio::test] +async fn test_feature_flag_event_send_false() { + // Test that send_feature_flag_events=false disables event capture + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/capture/"); + then.status(200); + }); + + // Create client with send_feature_flag_events disabled + let options = posthog_rs::ClientOptionsBuilder::default() + .api_key("test_api_key".to_string()) + .host(server.base_url()) + .send_feature_flag_events(false) + .build() + .unwrap(); + + let client = posthog_rs::client(options).await; + + let result = client + .get_feature_flag("test-flag", "test-user", None, None, None) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + flags_mock.assert(); + capture_mock.assert_hits(0); // No event captured when disabled +} + +#[tokio::test] +async fn test_feature_flag_event_with_variant() { + // Test that multivariate flag variant is captured in event + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "variant-flag": { + "key": "variant-flag", + "enabled": true, + "variant": "control" + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/"); + then.status(200); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .get_feature_flag("variant-flag", "test-user", None, None, None) + .await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(FlagValue::String("control".to_string())) + ); + flags_mock.assert(); + capture_mock.assert(); +} + +#[tokio::test] +async fn test_is_feature_enabled_captures_event() { + // Test that is_feature_enabled also captures events + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "enabled-flag": { + "key": "enabled-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called" + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()).await; + + let result = client + .is_feature_enabled("enabled-flag", "test-user", None, None, None) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), true); + flags_mock.assert(); + capture_mock.assert(); +} diff --git a/tests/test_blocking.rs b/tests/test_blocking.rs new file mode 100644 index 0000000..bf8443d --- /dev/null +++ b/tests/test_blocking.rs @@ -0,0 +1,502 @@ +#![cfg(not(feature = "async-client"))] + +use httpmock::prelude::*; +use posthog_rs::FlagValue; +use serde_json::json; +use std::collections::HashMap; + +fn create_test_client(base_url: String) -> posthog_rs::Client { + // Use the From implementation to ensure endpoint_manager is set up correctly + let options: posthog_rs::ClientOptions = (("test_api_key", base_url.as_str())).into(); + posthog_rs::client(options) +} + +#[test] +fn test_get_all_feature_flags() { + let server = MockServer::start(); + + let mock_response = json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + }, + "disabled-flag": { + "key": "disabled-flag", + "enabled": false, + "variant": null + }, + "variant-flag": { + "key": "variant-flag", + "enabled": true, + "variant": "control", + "metadata": { + "id": 1, + "version": 1, + "payload": "{\"color\": \"blue\", \"size\": \"large\"}" + } + } + } + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user" + })); + then.status(200) + .header("content-type", "application/json") + .json_body(mock_response); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_all_flags_and_payloads("test-user".to_string(), None, None, None); + + assert!(result.is_ok()); + let (feature_flags, payloads) = result.unwrap(); + + assert_eq!( + feature_flags.get("test-flag"), + Some(&FlagValue::Boolean(true)) + ); + assert_eq!( + feature_flags.get("disabled-flag"), + Some(&FlagValue::Boolean(false)) + ); + assert_eq!( + feature_flags.get("variant-flag"), + Some(&FlagValue::String("control".to_string())) + ); + + assert!(payloads.contains_key("variant-flag")); + + flags_mock.assert(); +} + +#[test] +fn test_is_feature_enabled() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "enabled-flag": { + "key": "enabled-flag", + "enabled": true, + "variant": null + }, + "disabled-flag": { + "key": "disabled-flag", + "enabled": false, + "variant": null + } + } + })); + }); + + let client = create_test_client(server.base_url()); + + let enabled_result = client.is_feature_enabled( + "enabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + let disabled_result = client.is_feature_enabled( + "disabled-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(disabled_result.is_ok()); + assert_eq!(disabled_result.unwrap(), false); + + flags_mock.assert_hits(2); +} + +#[test] +fn test_get_feature_flag_with_properties() { + let server = MockServer::start(); + + let person_properties = json!({ + "country": "US", + "age": 25, + "plan": "premium" + }); + + let flags_mock = server.mock(|when, then| { + when.method(POST) + .path("/flags/") + .query_param("v", "2") + .json_body(json!({ + "api_key": "test_api_key", + "distinct_id": "test-user", + "person_properties": person_properties + })); + then.status(200).json_body(json!({ + "flags": { + "premium-feature": { + "key": "premium-feature", + "enabled": true, + "variant": null + } + } + })); + }); + + let client = create_test_client(server.base_url()); + + let mut props = HashMap::new(); + props.insert("country".to_string(), json!("US")); + props.insert("age".to_string(), json!(25)); + props.insert("plan".to_string(), json!("premium")); + + let result = client.get_feature_flag( + "premium-feature".to_string(), + "test-user".to_string(), + None, + Some(props), + None, + ); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + + flags_mock.assert(); +} + +#[test] +fn test_multivariate_flag() { + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "experiment": { + "key": "experiment", + "enabled": true, + "variant": "variant-b" + } + } + })); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flag( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(FlagValue::String("variant-b".to_string())) + ); + + let enabled_result = client.is_feature_enabled( + "experiment".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(enabled_result.is_ok()); + assert_eq!(enabled_result.unwrap(), true); + + flags_mock.assert_hits(2); +} + +#[test] +fn test_api_error_handling() { + let server = MockServer::start(); + + let error_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(500).body("Internal Server Error"); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flags("test-user".to_string(), None, None, None); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("500")); + + error_mock.assert(); +} + +// Feature Flag Event Tests + +#[test] +fn test_feature_flag_event_captured() { + // Test that $feature_flag_called event is captured when calling get_feature_flag + let server = MockServer::start(); + + // Mock the flags endpoint + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + // Mock the capture endpoint to verify event is sent + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called", + "properties": { + "$feature_flag": "test-flag", + "$feature_flag_response": true + } + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flag( + "test-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(FlagValue::Boolean(true))); + flags_mock.assert(); + capture_mock.assert(); +} + +#[test] +fn test_feature_flag_event_deduplication() { + // Test that calling same flag for same user doesn't send duplicate events + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called" + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()); + + // First call - should capture event + let _ = client.get_feature_flag( + "test-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + // Second call - should NOT capture event (deduplication) + let _ = client.get_feature_flag( + "test-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + flags_mock.assert_hits(2); + capture_mock.assert_hits(1); // Only 1 event captured, not 2 +} + +#[test] +fn test_feature_flag_event_different_user() { + // Test that calling same flag for different user captures new event + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called", + "properties": { + "$feature_flag": "test-flag", + "$feature_flag_response": true + } + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()); + + // Call for user1 - should capture event + let _ = client.get_feature_flag( + "test-flag".to_string(), + "user1".to_string(), + None, + None, + None, + ); + + // Call for user2 - should capture event (different user) + let _ = client.get_feature_flag( + "test-flag".to_string(), + "user2".to_string(), + None, + None, + None, + ); + + flags_mock.assert_hits(2); + capture_mock.assert_hits(2); // 2 events captured for different users +} + +#[test] +fn test_feature_flag_event_with_variant() { + // Test that multivariate flags capture the variant value correctly + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": "variant-b" + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called", + "properties": { + "$feature_flag": "test-flag", + "$feature_flag_response": "variant-b", + "$feature/test-flag": "variant-b" + } + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()); + + let result = client.get_feature_flag( + "test-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(FlagValue::String("variant-b".to_string())) + ); + flags_mock.assert(); + capture_mock.assert(); +} + +#[test] +fn test_is_feature_enabled_captures_event() { + // Test that is_feature_enabled also captures $feature_flag_called event + let server = MockServer::start(); + + let flags_mock = server.mock(|when, then| { + when.method(POST).path("/flags/").query_param("v", "2"); + then.status(200).json_body(json!({ + "flags": { + "test-flag": { + "key": "test-flag", + "enabled": true, + "variant": null + } + } + })); + }); + + let capture_mock = server.mock(|when, then| { + when.method(POST).path("/i/v0/e/").json_body_partial( + json!({ + "event": "$feature_flag_called" + }) + .to_string(), + ); + then.status(200); + }); + + let client = create_test_client(server.base_url()); + + let result = client.is_feature_enabled( + "test-flag".to_string(), + "test-user".to_string(), + None, + None, + None, + ); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), true); + flags_mock.assert(); + capture_mock.assert(); +} diff --git a/tests/test_local_evaluation.rs b/tests/test_local_evaluation.rs new file mode 100644 index 0000000..2a58159 --- /dev/null +++ b/tests/test_local_evaluation.rs @@ -0,0 +1,97 @@ +#[cfg(feature = "async-client")] +use httpmock::prelude::*; +#[cfg(feature = "async-client")] +use posthog_rs::ClientOptionsBuilder; +#[cfg(feature = "async-client")] +use serde_json::json; +#[cfg(feature = "async-client")] +use std::collections::HashMap; +#[cfg(feature = "async-client")] +use std::time::Duration; + +#[cfg(feature = "async-client")] +#[tokio::test] +async fn test_local_evaluation_with_mock_server() { + let server = MockServer::start(); + + // Mock the local evaluation endpoint + let mock_flags = json!({ + "flags": [ + { + "key": "feature-a", + "active": true, + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 50.0, + "variant": null + } + ], + "multivariate": null, + "payloads": {} + } + }, + { + "key": "feature-b", + "active": true, + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "@company.com", + "operator": "icontains" + } + ], + "rollout_percentage": 100.0, + "variant": null + } + ], + "multivariate": null, + "payloads": {} + } + } + ], + "group_type_mapping": {}, + "cohorts": {} + }); + + let eval_mock = server.mock(|when, then| { + when.method(GET) + .path("/api/feature_flag/local_evaluation/") + .header("Authorization", "Bearer test_personal_key") + .query_param("token", "test_project_key") + .query_param("send_cohorts", ""); + then.status(200).json_body(mock_flags); + }); + + // Create client with local evaluation enabled + let options = ClientOptionsBuilder::default() + .host(server.base_url()) + .api_key("test_project_key".to_string()) + .personal_api_key("test_personal_key".to_string()) + .enable_local_evaluation(true) + .poll_interval_seconds(60) + .build() + .unwrap(); + + let client = posthog_rs::client(options).await; + + // Give it a moment to load initial flags + tokio::time::sleep(Duration::from_millis(100)).await; + + // Test local evaluation + let mut properties = HashMap::new(); + properties.insert("email".to_string(), json!("test@company.com")); + + let result = client + .get_feature_flag("feature-b", "user-123", None, Some(properties), None) + .await; + + assert!(result.is_ok()); + // The actual result depends on whether the mock was hit and processed + + eval_mock.assert(); +}