diff --git a/lazer/Cargo.lock b/lazer/Cargo.lock index eef110ff36..3d5cea2731 100644 --- a/lazer/Cargo.lock +++ b/lazer/Cargo.lock @@ -1966,6 +1966,21 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2866,6 +2881,12 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -3048,6 +3069,23 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "native-tls" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.26.4" @@ -3280,12 +3318,50 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl" +version = "0.10.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5e534d133a060a3c19daec1eb3e98ec6f4685978834f2dbadfe2ec215bab64e" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "opentelemetry" version = "0.17.0" @@ -3682,6 +3758,23 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "pyth-lazer-consumer" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures-util", + "http", + "pyth-lazer-protocol 0.4.0", + "rand 0.8.5", + "serde", + "serde_json", + "tokio", + "tokio-tungstenite", + "tracing", + "ttl_cache", +] + [[package]] name = "pyth-lazer-protocol" version = "0.1.3" @@ -6485,6 +6578,16 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -6530,8 +6633,10 @@ checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" dependencies = [ "futures-util", "log", + "native-tls", "rustls", "tokio", + "tokio-native-tls", "tokio-rustls", "tungstenite", "webpki-roots 0.25.4", @@ -6688,6 +6793,15 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "ttl_cache" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "tungstenite" version = "0.20.1" @@ -6700,6 +6814,7 @@ dependencies = [ "http", "httparse", "log", + "native-tls", "rand 0.8.5", "rustls", "sha1", @@ -6860,6 +6975,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "vec_map" version = "0.8.2" diff --git a/lazer/Cargo.toml b/lazer/Cargo.toml index 1e83d0a60f..4f5f60c427 100644 --- a/lazer/Cargo.toml +++ b/lazer/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ "sdk/rust/protocol", + "sdk/rust/consumer", "contracts/solana/programs/pyth-lazer-solana-contract", ] diff --git a/lazer/sdk/rust/consumer/Cargo.toml b/lazer/sdk/rust/consumer/Cargo.toml new file mode 100644 index 0000000000..0312458197 --- /dev/null +++ b/lazer/sdk/rust/consumer/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "pyth-lazer-consumer" +version = "0.1.0" +edition = "2021" +description = "Rust consumer SDK for Pyth Lazer" +license = "Apache-2.0" + +[dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "sync"] } +tokio-tungstenite = { version = "0.20", features = ["native-tls"] } +futures-util = "0.3" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +anyhow = "1.0" +tracing = "0.1" +http = "0.2" +rand = { version = "0.8", features = ["std"] } +ttl_cache = "0.5" +pyth-lazer-protocol = { path = "../protocol" } diff --git a/lazer/sdk/rust/consumer/examples/basic.rs b/lazer/sdk/rust/consumer/examples/basic.rs new file mode 100644 index 0000000000..4ac168f6f1 --- /dev/null +++ b/lazer/sdk/rust/consumer/examples/basic.rs @@ -0,0 +1,49 @@ +use { + anyhow::Result, + pyth_lazer_consumer::{ + Chain, DeliveryFormat, PriceFeedId, PriceFeedProperty, PythLazerConsumer, Response, + }, +}; + +#[tokio::main] +async fn main() -> Result<()> { + let mut consumer = PythLazerConsumer::new( + vec!["wss://endpoint.pyth.network".to_string()], + "your_token_here".to_string(), + ) + .await?; + + // Connect to the WebSocket server + consumer.connect().await?; + + // Subscribe to some price feeds + consumer + .subscribe( + 1, // subscription_id + vec![PriceFeedId(1)], + Some(vec![PriceFeedProperty::Price, PriceFeedProperty::Exponent]), + Some(vec![Chain::Evm]), + Some(DeliveryFormat::Json), + ) + .await?; + + // Receive updates + let mut rx = consumer.subscribe_to_updates(); + while let Ok(update) = rx.recv().await { + if let Response::StreamUpdated(update) = update { + println!( + "Received update for subscription {}", + update.subscription_id.0 + ); + if let Some(parsed) = update.payload.parsed { + for feed in parsed.price_feeds { + println!(" Feed ID: {:?}", feed.price_feed_id); + println!(" Price: {:?}", feed.price); + println!(" Exponent: {:?}", feed.exponent); + } + } + } + } + + Ok(()) +} diff --git a/lazer/sdk/rust/consumer/src/client.rs b/lazer/sdk/rust/consumer/src/client.rs new file mode 100644 index 0000000000..db1670ff5b --- /dev/null +++ b/lazer/sdk/rust/consumer/src/client.rs @@ -0,0 +1,500 @@ +use { + anyhow::{anyhow, Result}, + futures_util::{SinkExt, StreamExt}, + pyth_lazer_protocol::{ + router::{ + Chain, Channel, DeliveryFormat, FixedRate, PriceFeedId, PriceFeedProperty, + SubscriptionParams, SubscriptionParamsRepr, + }, + subscription::{Request, Response, SubscribeRequest, SubscriptionId, UnsubscribeRequest}, + }, + std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, + }, + tokio::{ + net::TcpStream, + sync::{broadcast, Mutex}, + }, + tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream, + }, + tracing::{debug, error, info, warn}, + ttl_cache::TtlCache, +}; + +const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5); +const DEFAULT_NUM_CONNECTIONS: usize = 3; +const MAX_NUM_CONNECTIONS: usize = 50; +const STREAM_POOL_CHANNEL_SIZE: usize = 100_000; +const TICKER_STREAM_CHANNEL_SIZE: usize = 100_000; +const DEDUP_CACHE_SIZE: usize = 100_000; +const DEDUP_TTL: Duration = Duration::from_secs(10); +const RECONNECT_WAIT: Duration = Duration::from_secs(1); +const RECONNECT_STAGGER: Duration = Duration::from_secs(1); + +type WsStream = WebSocketStream>; +type WsSink = futures_util::stream::SplitSink; +type WsSource = futures_util::stream::SplitStream; + +pub struct ConnectionState { + pub id: usize, + pub url: String, + pub(crate) write: Option, + pub(crate) last_message: Option, + pub(crate) healthy: bool, + pub(crate) error_count: usize, +} + +impl ConnectionState { + pub fn new(id: usize, url: String) -> Self { + Self { + id, + url, + write: None, + last_message: None, + healthy: false, + error_count: 0, + } + } + + fn mark_healthy(&mut self) { + self.healthy = true; + self.error_count = 0; + self.last_message = Some(Instant::now()); + } + + fn mark_error(&mut self) { + self.healthy = false; + self.error_count += 1; + self.write = None; + } +} + +impl Clone for ConnectionState { + fn clone(&self) -> Self { + Self { + id: self.id, + url: self.url.clone(), + write: None, + last_message: self.last_message, + healthy: self.healthy, + error_count: self.error_count, + } + } +} + +#[derive(Clone)] +pub struct PythLazerConsumer { + urls: Vec, + token: String, + active_subscriptions: Arc>>, + tx: broadcast::Sender, + stream_tx: broadcast::Sender, + connections: Arc>>>>, + pub(crate) message_cache: Arc>>, + reconnect_attempts: Arc>>, +} + +impl PythLazerConsumer { + pub fn get_tx(&self) -> &broadcast::Sender { + &self.tx + } + + pub async fn process_message( + &self, + message: Message, + cache: &Arc>>, + ) -> Result<()> { + if let Message::Text(text) = message { + if let Ok(response) = serde_json::from_str::(&text) { + let msg_key = format!("{:?}", &response); + let mut cache = cache.lock().await; + + if cache.contains_key(&msg_key) { + return Ok(()); + } + + cache.insert(msg_key, (), DEDUP_TTL); + + if let Err(e) = self.tx.send(response) { + return Err(anyhow!("Failed to forward message: {}", e)); + } + } + } + Ok(()) + } +} + +impl PythLazerConsumer { + pub async fn new(urls: Vec, token: String) -> Result { + let (tx, _) = broadcast::channel(TICKER_STREAM_CHANNEL_SIZE); + let (stream_tx, _) = broadcast::channel(STREAM_POOL_CHANNEL_SIZE); + let num_connections = urls + .len() + .clamp(DEFAULT_NUM_CONNECTIONS, MAX_NUM_CONNECTIONS); + + let mut connections = Vec::with_capacity(num_connections); + for i in 0..num_connections { + let url = urls[i % urls.len()].clone(); + let connection = ConnectionState::new(i, url); + connections.push(Arc::new(Mutex::new(connection))); + } + + let consumer = Self { + urls, + token, + active_subscriptions: Arc::new(Mutex::new(HashMap::new())), + tx, + stream_tx, + connections: Arc::new(Mutex::new(connections)), + message_cache: Arc::new(Mutex::new(TtlCache::new(DEDUP_CACHE_SIZE))), + reconnect_attempts: Arc::new(Mutex::new(HashMap::new())), + }; + + Ok(consumer) + } + + fn exponential_backoff(attempts: usize) -> Duration { + use rand::Rng; + + const BASE_DELAY: u64 = 100; // 100ms + const MAX_DELAY: u64 = 30_000; // 30s + + let base_delay = (2u64.pow(attempts as u32) * BASE_DELAY).min(MAX_DELAY); + let jitter = rand::thread_rng().gen_range(0..=(base_delay / 10)); // 10% jitter + Duration::from_millis(base_delay.saturating_add(jitter)) + } + + async fn connect_with_backoff(&self, connection: &mut ConnectionState) -> Result<()> { + loop { + let attempt_count = { + let mut attempts_guard = self.reconnect_attempts.lock().await; + *attempts_guard.entry(connection.id).or_insert(0) + }; + + match self.connect_single(connection).await { + Ok(_) => { + debug!("Connection {} established successfully", connection.id); + let mut attempts_guard = self.reconnect_attempts.lock().await; + *attempts_guard.entry(connection.id).or_insert(0) = 0; + return Ok(()); + } + Err(e) => { + error!("Connection {} failed: {}", connection.id, e); + { + let mut attempts_guard = self.reconnect_attempts.lock().await; + *attempts_guard.entry(connection.id).or_insert(0) += 1; + } + + let backoff = Self::exponential_backoff(attempt_count + 1); + warn!( + "Connection {} backing off for {:?} (attempt {})", + connection.id, + backoff, + attempt_count + 1 + ); + + tokio::time::sleep(backoff).await; + + // Reset attempts if we've tried too many times + if attempt_count >= 10 { + let mut attempts_guard = self.reconnect_attempts.lock().await; + warn!( + "Connection {} resetting attempt counter after {} attempts", + connection.id, attempt_count + ); + *attempts_guard.entry(connection.id).or_insert(0) = 0; + } + } + } + } + } + + async fn connect_single(&self, connection: &mut ConnectionState) -> Result<()> { + debug!( + "Attempting to connect to {} (connection {})", + connection.url, connection.id + ); + + let request = http::Request::builder() + .uri(&connection.url) + .header("Authorization", format!("Bearer {}", self.token)) + .body(())?; + + let (ws_stream, _) = + match tokio::time::timeout(CONNECTION_TIMEOUT, connect_async(request)).await { + Ok(Ok(result)) => result, + Ok(Err(e)) => { + connection.mark_error(); + return Err(anyhow!("WebSocket connection failed: {}", e)); + } + Err(_) => { + connection.mark_error(); + return Err(anyhow!( + "Connection timed out after {:?}", + CONNECTION_TIMEOUT + )); + } + }; + + let (write, read) = ws_stream.split(); + connection.write = Some(write); + connection.mark_healthy(); + + // Set up message handling + let tx = self.tx.clone(); + let cache = self.message_cache.clone(); + let connection_id = connection.id; + let message_handler = tokio::spawn(async move { + if let Err(e) = Self::handle_messages(read, tx, cache).await { + error!( + "Message handler for connection {} failed: {}", + connection_id, e + ); + } + }); + + // Resubscribe to active subscriptions + if let Some(write) = &mut connection.write { + let subscriptions = self.active_subscriptions.lock().await; + for request in subscriptions.values() { + debug!( + "Resubscribing to feed {} on connection {}", + request.subscription_id.0, connection.id + ); + let msg = serde_json::to_string(&Request::Subscribe(request.clone()))?; + if let Err(e) = write.send(Message::Text(msg)).await { + error!( + "Failed to resubscribe on connection {}: {}", + connection.id, e + ); + connection.mark_error(); + return Err(e.into()); + } + } + debug!( + "Successfully resubscribed {} feeds on connection {}", + subscriptions.len(), + connection.id + ); + } + + // Wait for the message handler to complete + match message_handler.await { + Ok(_) => Ok(()), + Err(e) => { + error!("Message handler task panicked: {}", e); + Err(anyhow!("Message handler task failed")) + } + } + } + + async fn run_connection_loop(&self, mut connection: ConnectionState) { + loop { + if let Err(e) = self.connect_with_backoff(&mut connection).await { + error!("Connection {} failed permanently: {}", connection.id, e); + tokio::time::sleep(RECONNECT_WAIT).await; + continue; + } + + // If we get here, the connection was closed gracefully + warn!( + "Connection {} closed, waiting {} seconds before reconnecting...", + connection.id, + RECONNECT_WAIT.as_secs() + ); + tokio::time::sleep(RECONNECT_WAIT).await; + } + } + + pub async fn connect(&mut self) -> Result<()> { + let num_connections = self.urls.len().min(MAX_NUM_CONNECTIONS); + let mut connections = Vec::with_capacity(num_connections); + + for i in 0..num_connections { + let url = self.urls[i % self.urls.len()].clone(); + let connection = ConnectionState::new(i, url); + connections.push(Arc::new(Mutex::new(connection))); + } + + *self.connections.lock().await = connections.clone(); + + for (i, connection) in connections.into_iter().enumerate() { + let consumer = self.clone(); + + // Stagger connection attempts + if i > 0 { + tokio::time::sleep(RECONNECT_STAGGER).await; + } + + tokio::spawn(async move { + let conn = connection.lock().await.clone(); + consumer.run_connection_loop(conn).await; + }); + } + + Ok(()) + } + + /// Subscribe to price feed updates. + /// + /// # Arguments + /// * `subscription_id` - Unique identifier for this subscription + /// * `feed_ids` - List of price feed IDs to subscribe to + /// * `properties` - Optional list of properties to receive (defaults to [Price]) + /// * `chains` - Optional list of chains to receive updates for (defaults to [EVM, Solana]) + /// * `delivery_format` - Optional message format (defaults to JSON) + pub async fn subscribe( + &mut self, + subscription_id: u64, + feed_ids: Vec, + properties: Option>, + chains: Option>, + delivery_format: Option, + ) -> Result<()> { + let params = SubscriptionParams::new(SubscriptionParamsRepr { + price_feed_ids: feed_ids, + properties: properties.unwrap_or_else(|| vec![PriceFeedProperty::Price]), + chains: chains.unwrap_or_else(|| vec![Chain::Evm, Chain::Solana]), + delivery_format: delivery_format.unwrap_or_default(), + json_binary_encoding: Default::default(), + parsed: true, + channel: Channel::FixedRate(FixedRate::MIN), + }) + .map_err(|e| anyhow!("Invalid subscription parameters: {}", e))?; + + let request = SubscribeRequest { + subscription_id: SubscriptionId(subscription_id), + params, + }; + + // Send subscription request through all active connections + let connections = self.connections.lock().await; + for connection in connections.iter() { + let mut conn = connection.lock().await; + if let Some(write) = &mut conn.write { + let msg = serde_json::to_string(&Request::Subscribe(request.clone()))?; + write.send(Message::Text(msg)).await?; + } + } + + self.active_subscriptions + .lock() + .await + .insert(subscription_id, request); + Ok(()) + } + + pub(crate) async fn handle_messages( + mut read: WsSource, + tx: broadcast::Sender, + cache: Arc>>, + ) -> Result<()> { + while let Some(msg) = read.next().await { + match msg { + Ok(Message::Text(text)) => { + match serde_json::from_str::(&text) { + Ok(response) => { + // Check if we've seen this message recently + let msg_key = format!("{:?}", &response); + let mut cache = cache.lock().await; + + if cache.contains_key(&msg_key) { + debug!("Dropping duplicate message"); + continue; + } + + // Cache the message with TTL + cache.insert(msg_key, (), DEDUP_TTL); + + match &response { + Response::Error(err) => { + error!("Server error: {}", err.error); + return Err(anyhow!("Server error: {}", err.error)); + } + Response::SubscriptionError(err) => { + error!( + "Subscription error for ID {}: {}", + err.subscription_id.0, err.error + ); + return Err(anyhow!("Subscription error: {}", err.error)); + } + Response::StreamUpdated(update) => { + debug!( + "Received update for subscription {}", + update.subscription_id.0 + ); + } + _ => debug!("Received response: {:?}", response), + } + + if let Err(e) = tx.send(response.clone()) { + error!("Failed to forward message to ticker stream: {}", e); + return Err(anyhow!("Failed to forward message: {}", e)); + } + + // Also send to the stream pool for redundancy + if let Err(e) = tx.send(response) { + error!("Failed to forward message to stream pool: {}", e); + return Err(anyhow!("Failed to forward message: {}", e)); + } + } + Err(e) => { + error!("Failed to parse message: {}", e); + return Err(anyhow!("Failed to parse message: {}", e)); + } + } + } + Ok(Message::Close(frame)) => { + info!("WebSocket connection closed by server: {:?}", frame); + return Ok(()); + } + Ok(Message::Ping(_)) => { + debug!("Received ping"); + } + Ok(Message::Pong(_)) => { + debug!("Received pong"); + } + Err(e) => { + error!("WebSocket error: {}", e); + return Err(anyhow!("WebSocket error: {}", e)); + } + _ => {} + } + } + Ok(()) + } + + pub async fn unsubscribe(&mut self, subscription_id: u64) -> Result<()> { + let request = UnsubscribeRequest { + subscription_id: SubscriptionId(subscription_id), + }; + + // Send unsubscribe request through all active connections + let connections = self.connections.lock().await; + for connection in connections.iter() { + let mut conn = connection.lock().await; + if let Some(write) = &mut conn.write { + let msg = serde_json::to_string(&Request::Unsubscribe(request.clone()))?; + write.send(Message::Text(msg)).await?; + } + } + + self.active_subscriptions + .lock() + .await + .remove(&subscription_id); + Ok(()) + } + + pub fn subscribe_to_updates(&self) -> broadcast::Receiver { + self.tx.subscribe() + } + + /// Subscribe to the combined stream pool that receives messages from all connections + pub fn subscribe_to_stream_pool(&self) -> broadcast::Receiver { + self.stream_tx.subscribe() + } +} diff --git a/lazer/sdk/rust/consumer/src/lib.rs b/lazer/sdk/rust/consumer/src/lib.rs new file mode 100644 index 0000000000..1661ab62a0 --- /dev/null +++ b/lazer/sdk/rust/consumer/src/lib.rs @@ -0,0 +1,7 @@ +mod client; + +pub use client::{ConnectionState, PythLazerConsumer}; +pub use pyth_lazer_protocol::{ + router::{Chain, DeliveryFormat, PriceFeedId, PriceFeedProperty}, + subscription::Response, +}; diff --git a/lazer/sdk/rust/consumer/tests/integration_test.rs b/lazer/sdk/rust/consumer/tests/integration_test.rs new file mode 100644 index 0000000000..d2e5c6370a --- /dev/null +++ b/lazer/sdk/rust/consumer/tests/integration_test.rs @@ -0,0 +1,132 @@ +use { + anyhow::Result, + pyth_lazer_consumer::{Chain, DeliveryFormat, PriceFeedProperty, PythLazerConsumer, Response}, + pyth_lazer_protocol::{ + router::{JsonUpdate, PriceFeedId}, + subscription::{StreamUpdatedResponse, SubscribedResponse, SubscriptionId}, + }, + std::{sync::Arc, time::Duration}, + tokio::sync::Mutex, + tokio_tungstenite::tungstenite::Message, + ttl_cache::TtlCache, +}; + +// Test helper trait +#[cfg(test)] +trait TestHelpers { + fn handle_message(&self, message: Message); +} + +#[cfg(test)] +impl TestHelpers for PythLazerConsumer { + fn handle_message(&self, message: Message) { + if let Message::Text(text) = message { + if let Ok(response) = serde_json::from_str::(&text) { + let _ = self.get_tx().send(response); + } + } + } +} + +#[cfg(test)] +#[tokio::test] +async fn test_subscription_lifecycle() -> Result<()> { + let mut consumer = PythLazerConsumer::new( + vec!["wss://test.pyth.network".to_string()], + "test_token".to_string(), + ) + .await?; + + // Connect before testing subscriptions + consumer.connect().await?; + + // Test subscription + let subscription_id = 1; + let feed_ids = vec![PriceFeedId(1)]; + let properties = Some(vec![PriceFeedProperty::Price]); + let chains = Some(vec![Chain::Evm]); + let delivery_format = Some(DeliveryFormat::Json); + + consumer + .subscribe( + subscription_id, + feed_ids, + properties, + chains, + delivery_format, + ) + .await?; + + // Verify subscription was created + let mut rx = consumer.subscribe_to_updates(); + + // Simulate subscription confirmation + let subscribed_response = Response::Subscribed(SubscribedResponse { + subscription_id: SubscriptionId(1), + }); + let confirmation = Message::Text(serde_json::to_string(&subscribed_response).unwrap()); + consumer.handle_message(confirmation); + + // Wait a short time for the message to be processed + tokio::time::sleep(Duration::from_millis(100)).await; + + // Try to receive the confirmation message + let result = rx.try_recv(); + assert!( + result.is_ok(), + "Expected to receive subscription confirmation" + ); + + if let Ok(Response::Subscribed(response)) = result { + assert_eq!(response.subscription_id, SubscriptionId(1)); + } else { + panic!("Expected Subscribed response"); + } + + // Test unsubscribe + consumer.unsubscribe(subscription_id).await?; + let rx = consumer.subscribe_to_updates(); + assert_eq!(rx.len(), 0); + + Ok(()) +} + +#[cfg(test)] +#[tokio::test] +async fn test_message_deduplication() -> Result<()> { + let mut consumer = PythLazerConsumer::new( + vec!["wss://test.pyth.network".to_string()], + "test_token".to_string(), + ) + .await?; + + // Connect and subscribe to receive messages + consumer.connect().await?; + let mut rx = consumer.subscribe_to_updates(); + + // Create a test stream update message + let stream_update = Response::StreamUpdated(StreamUpdatedResponse { + subscription_id: SubscriptionId(1), + payload: JsonUpdate { + parsed: None, + evm: None, + solana: None, + }, + }); + + let message = Message::Text(serde_json::to_string(&stream_update).unwrap()); + let cache = Arc::new(Mutex::new(TtlCache::new(100))); + + consumer.process_message(message.clone(), &cache).await?; + consumer.process_message(message, &cache).await?; + + // Wait for the first message + let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await; + assert!(result.is_ok(), "Expected to receive first message"); + + // Try to receive a second message (should fail due to deduplication) + let result = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await; + assert!(result.is_err(), "Should not receive duplicate message"); + + Ok(()) +}