|
| 1 | +use libp2p::gossipsub::{DataTransform, Message, RawMessage, TopicHash}; |
| 2 | +use std::io::{Error, ErrorKind}; |
| 3 | +use std::time::{SystemTime, UNIX_EPOCH}; |
| 4 | + |
| 5 | +/// A `DataTransform` implementation that adds & checks a timestamp to the message. |
| 6 | +pub struct TTLDataTransform { |
| 7 | + /// Time-to-live, e.g. obtained from some `duration.as_secs()`. |
| 8 | + ttl_secs: u64, |
| 9 | +} |
| 10 | + |
| 11 | +impl TTLDataTransform { |
| 12 | + const MID_SIZE: usize = 8; |
| 13 | + |
| 14 | + #[inline(always)] |
| 15 | + fn get_time_secs(&self) -> u64 { |
| 16 | + SystemTime::now() |
| 17 | + .duration_since(UNIX_EPOCH) |
| 18 | + .unwrap() |
| 19 | + .as_secs() |
| 20 | + } |
| 21 | +} |
| 22 | + |
| 23 | +impl DataTransform for TTLDataTransform { |
| 24 | + fn inbound_transform(&self, mut raw_message: RawMessage) -> Result<Message, Error> { |
| 25 | + // check length |
| 26 | + if raw_message.data.len() < Self::MID_SIZE { |
| 27 | + return Err(Error::new(ErrorKind::InvalidInput, "Message too short")); |
| 28 | + } |
| 29 | + |
| 30 | + // parse time |
| 31 | + let raw_data = raw_message.data.split_off(Self::MID_SIZE); |
| 32 | + let msg_time = u64::from_be_bytes(raw_message.data[0..Self::MID_SIZE].try_into().unwrap()); |
| 33 | + |
| 34 | + // check ttl |
| 35 | + if msg_time + self.ttl_secs < self.get_time_secs() { |
| 36 | + return Err(Error::new(ErrorKind::InvalidInput, "Message expired")); |
| 37 | + } |
| 38 | + |
| 39 | + Ok(Message { |
| 40 | + source: raw_message.source, |
| 41 | + data: raw_data, |
| 42 | + sequence_number: raw_message.sequence_number, |
| 43 | + topic: raw_message.topic, |
| 44 | + }) |
| 45 | + } |
| 46 | + |
| 47 | + fn outbound_transform( |
| 48 | + &self, |
| 49 | + _topic: &TopicHash, |
| 50 | + data: Vec<u8>, |
| 51 | + ) -> Result<Vec<u8>, std::io::Error> { |
| 52 | + let msg_time = self.get_time_secs().to_be_bytes(); |
| 53 | + |
| 54 | + // prepend time bytes to the data |
| 55 | + let mut transformed_data = Vec::with_capacity(Self::MID_SIZE + data.len()); |
| 56 | + transformed_data.extend_from_slice(&msg_time); |
| 57 | + transformed_data.extend_from_slice(&data); |
| 58 | + |
| 59 | + Ok(transformed_data) |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +#[cfg(test)] |
| 64 | +mod tests { |
| 65 | + use std::time::Duration; |
| 66 | + |
| 67 | + use super::*; |
| 68 | + |
| 69 | + #[test] |
| 70 | + fn test_ttl_data_transform() { |
| 71 | + let data = vec![1, 2, 3, 4, 5]; |
| 72 | + let ttl_secs = Duration::from_secs(100).as_secs(); |
| 73 | + let ttl_data_transform = TTLDataTransform { ttl_secs }; |
| 74 | + let topic = TopicHash::from_raw("topic"); |
| 75 | + |
| 76 | + // outbound transform |
| 77 | + let transformed_data = ttl_data_transform |
| 78 | + .outbound_transform(&topic, data.clone()) |
| 79 | + .unwrap(); |
| 80 | + |
| 81 | + // inbound transform |
| 82 | + let raw_message = RawMessage { |
| 83 | + source: Default::default(), |
| 84 | + data: transformed_data, |
| 85 | + sequence_number: None, |
| 86 | + topic, |
| 87 | + signature: Default::default(), |
| 88 | + key: Default::default(), |
| 89 | + validated: false, |
| 90 | + }; |
| 91 | + let message = ttl_data_transform.inbound_transform(raw_message).unwrap(); |
| 92 | + |
| 93 | + assert_eq!(message.data, data); |
| 94 | + } |
| 95 | +} |
0 commit comments