Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion payjoin/src/core/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ pub(crate) use error::InternalSessionError;
pub use error::SessionError;
use serde::de::Deserializer;
use serde::{Deserialize, Serialize};
pub use session::{replay_event_log, SessionEvent, SessionHistory, SessionOutcome, SessionStatus};
pub use session::{
replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome,
SessionStatus,
};
use url::Url;
#[cfg(target_arch = "wasm32")]
use web_time::Duration;
Expand Down
228 changes: 165 additions & 63 deletions payjoin/src/core/receive/v2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,38 @@ use serde::{Deserialize, Serialize};
use super::{ReceiveSession, SessionContext};
use crate::error::{InternalReplayError, ReplayError};
use crate::output_substitution::OutputSubstitution;
use crate::persist::SessionPersister;
use crate::persist::{AsyncSessionPersister, SessionPersister};
use crate::receive::{InputPair, JsonReply, OriginalPayload, PsbtContext};
use crate::{ImplementationError, PjUri};

fn replay_events(
mut logs: impl Iterator<Item = SessionEvent>,
) -> Result<(ReceiveSession, Vec<SessionEvent>), ReplayError<ReceiveSession, SessionEvent>> {
let first_event = logs.next().ok_or(InternalReplayError::NoEvents)?;
let mut session_events = vec![first_event.clone()];
let mut receiver = match first_event {
SessionEvent::Created(context) => ReceiveSession::new(context),
_ => return Err(InternalReplayError::InvalidEvent(Box::new(first_event), None).into()),
};

for event in logs {
session_events.push(event.clone());
receiver = receiver.process_event(event)?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the previous code if process_event failed we close the session. Is this a regression?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks the close was move to the replay_events call side. Is that right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replay_events does no IO so that it can be reused in the async version of replay_event_log. So it just returns an error here and that's handled with a session close in replay_event_log/_async:

    let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
        Ok(r) => r,
        Err(e) => {
            persister.close().await.map_err(|ce| {
                InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
            })?;
            return Err(e);
        }
    };

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a second commit that adds test coverage for this scenario to confirm no regression

}
Ok((receiver, session_events))
}

fn construct_history(
session_events: Vec<SessionEvent>,
) -> Result<SessionHistory, ReplayError<ReceiveSession, SessionEvent>> {
let history = SessionHistory::new(session_events);
let ctx = history.session_context();
if ctx.expiration.elapsed() {
return Err(InternalReplayError::Expired(ctx.expiration).into());
}
Ok(history)
}

/// Replay a receiver event log to get the receiver in its current state [ReceiveSession]
/// and a session history [SessionHistory]
pub fn replay_event_log<P>(
Expand All @@ -17,35 +45,49 @@ where
P::SessionEvent: Into<SessionEvent> + Clone,
P::SessionEvent: From<SessionEvent>,
{
let mut logs = persister
let logs = persister
.load()
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

let first_event = logs.next().ok_or(InternalReplayError::NoEvents)?.into();
let mut session_events = vec![first_event.clone()];
let mut receiver = match first_event {
SessionEvent::Created(context) => ReceiveSession::new(context),
_ => return Err(InternalReplayError::InvalidEvent(Box::new(first_event), None).into()),
let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
Ok(r) => r,
Err(e) => {
persister.close().map_err(|ce| {
InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
})?;
return Err(e);
}
};
for event in logs {
session_events.push(event.clone().into());
receiver = receiver.process_event(event.into()).map_err(|e| {
if let Err(storage_err) = persister.close() {
return InternalReplayError::PersistenceFailure(ImplementationError::new(
storage_err,
))
.into();
}
e
})?;
}

let history = SessionHistory::new(session_events.clone());
let ctx = history.session_context();
if ctx.expiration.elapsed() {
return Err(InternalReplayError::Expired(ctx.expiration).into());
}
let history = construct_history(session_events)?;
Ok((receiver, history))
}

/// Async version of [replay_event_log]
pub async fn replay_event_log_async<P>(
persister: &P,
) -> Result<(ReceiveSession, SessionHistory), ReplayError<ReceiveSession, SessionEvent>>
where
P: AsyncSessionPersister,
P::SessionEvent: Into<SessionEvent> + Clone,
P::SessionEvent: From<SessionEvent>,
{
let logs = persister
.load()
.await
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
Ok(r) => r,
Err(e) => {
persister.close().await.map_err(|ce| {
InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
})?;
return Err(e);
}
};

let history = construct_history(session_events)?;
Ok((receiver, history))
}

Expand Down Expand Up @@ -183,7 +225,7 @@ mod tests {
use payjoin_test_utils::{BoxError, EXAMPLE_URL};

use super::*;
use crate::persist::test_utils::InMemoryTestPersister;
use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister};
use crate::persist::NoopSessionPersister;
use crate::receive::tests::original_from_test_vector;
use crate::receive::v2::test::{mock_err, SHARED_CONTEXT};
Expand Down Expand Up @@ -268,39 +310,55 @@ mod tests {
}
}

#[derive(Clone)]
struct SessionHistoryExpectedOutcome {
fallback_tx: Option<bitcoin::Transaction>,
expected_status: SessionStatus,
}

#[derive(Clone)]
struct SessionHistoryTest {
events: Vec<SessionEvent>,
expected_session_history: SessionHistoryExpectedOutcome,
expected_receiver_state: ReceiveSession,
}

fn run_session_history_test(test: SessionHistoryTest) -> Result<(), BoxError> {
let persister = InMemoryTestPersister::<SessionEvent>::default();
for event in test.events.clone() {
persister.save_event(event)?;
}

let (receiver, session_history) = replay_event_log(&persister)?;
fn verify_session_result(
session_result: Result<
(ReceiveSession, SessionHistory),
crate::error::ReplayError<ReceiveSession, SessionEvent>,
>,
test: &SessionHistoryTest,
) {
let (receiver, session_history) = session_result.expect("replay should succeed");
assert_eq!(receiver, test.expected_receiver_state);
assert_eq!(session_history.fallback_tx(), test.expected_session_history.fallback_tx);
assert_eq!(session_history.status(), test.expected_session_history.expected_status);
let expected_reply_key = test.events.iter().find_map(|event| match event {
SessionEvent::RetrievedOriginalPayload { reply_key, .. } => reply_key.clone(),
_ => None,
});

assert_eq!(session_history.session_context().reply_key, expected_reply_key);
}

Ok(())
fn run_session_history_test(test: &SessionHistoryTest) {
let persister = InMemoryTestPersister::<SessionEvent>::default();
for event in test.events.clone() {
persister.save_event(event).expect("In memory persister shouldn't fail");
}
verify_session_result(replay_event_log(&persister), test);
}

#[test]
fn test_replaying_session_creation() -> Result<(), BoxError> {
async fn run_session_history_test_async(test: &SessionHistoryTest) {
let persister = InMemoryAsyncTestPersister::<SessionEvent>::default();
for event in test.events.clone() {
persister.save_event(event).await.expect("In memory persister shouldn't fail");
}
verify_session_result(replay_event_log_async(&persister).await, test);
}

#[tokio::test]
async fn test_replaying_session_creation() {
let session_context = SHARED_CONTEXT.clone();
let test = SessionHistoryTest {
events: vec![SessionEvent::Created(session_context.clone())],
Expand All @@ -313,25 +371,62 @@ mod tests {
session_context,
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_replaying_session_creation_with_expired_session() -> Result<(), BoxError> {
#[tokio::test]
async fn test_replaying_session_creation_with_expired_session() {
let expiration = (SystemTime::now() - Duration::from_secs(1)).try_into().unwrap();
let session_context = SessionContext { expiration, ..SHARED_CONTEXT.clone() };
let persister = InMemoryTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::Created(session_context))?;

let persister = InMemoryTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::Created(session_context.clone()));
let err = replay_event_log(&persister).expect_err("session should be expired");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::Expired(expiration).into();
assert_eq!(err.to_string(), expected_err.to_string());
Ok(())

let persister = InMemoryAsyncTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::Created(session_context)).await;
let err = replay_event_log_async(&persister).await.expect_err("session should be expired");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::Expired(expiration).into();
assert_eq!(err.to_string(), expected_err.to_string());
}

#[test]
fn test_replaying_unchecked_proposal() -> Result<(), BoxError> {
#[tokio::test]
async fn test_replaying_session_with_missing_created_event() {
let persister = InMemoryTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::CheckedBroadcastSuitability());
assert!(!persister.inner.read().expect("session read should succeed").is_closed);
let err = replay_event_log(&persister).expect_err("session replay should be fail");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::InvalidEvent(
Box::new(SessionEvent::CheckedBroadcastSuitability()),
None,
)
.into();
assert_eq!(err.to_string(), expected_err.to_string());
assert!(persister.inner.read().expect("lock should not be poisoned").is_closed);

let persister = InMemoryAsyncTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::CheckedBroadcastSuitability()).await;
assert!(!persister.inner.read().await.is_closed);
let err =
replay_event_log_async(&persister).await.expect_err("session replay should be fail");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::InvalidEvent(
Box::new(SessionEvent::CheckedBroadcastSuitability()),
None,
)
.into();
assert_eq!(err.to_string(), expected_err.to_string());
assert!(persister.inner.read().await.is_closed);
}

#[tokio::test]
async fn test_replaying_unchecked_proposal() {
let session_context = SHARED_CONTEXT.clone();
let original = original_from_test_vector();
let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1);
Expand All @@ -353,11 +448,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_replaying_unchecked_proposal_with_reply_key() -> Result<(), BoxError> {
#[tokio::test]
async fn test_replaying_unchecked_proposal_with_reply_key() {
let session_context = SHARED_CONTEXT.clone();
let original = original_from_test_vector();
let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1);
Expand All @@ -379,11 +475,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn getting_fallback_tx() -> Result<(), BoxError> {
#[tokio::test]
async fn getting_fallback_tx() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -413,11 +510,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_contributed_inputs() -> Result<(), BoxError> {
#[tokio::test]
async fn test_contributed_inputs() {
let persister = InMemoryTestPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -486,11 +584,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_payjoin_proposal() -> Result<(), BoxError> {
#[tokio::test]
async fn test_payjoin_proposal() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -561,11 +660,12 @@ mod tests {
},
expected_receiver_state: ReceiveSession::Closed(SessionOutcome::Success(vec![])),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_session_fatal_error() -> Result<(), BoxError> {
#[tokio::test]
async fn test_session_fatal_error() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -596,11 +696,12 @@ mod tests {
},
expected_receiver_state: ReceiveSession::Closed(SessionOutcome::Failure),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_session_transient_error() -> Result<(), BoxError> {
#[tokio::test]
async fn test_session_transient_error() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -632,7 +733,8 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
Expand Down
5 changes: 4 additions & 1 deletion payjoin/src/core/send/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ pub use error::{CreateRequestError, EncapsulationError};
use error::{InternalCreateRequestError, InternalEncapsulationError};
use ohttp::ClientResponse;
use serde::{Deserialize, Serialize};
pub use session::{replay_event_log, SessionEvent, SessionHistory, SessionOutcome, SessionStatus};
pub use session::{
replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome,
SessionStatus,
};
use url::Url;

use super::error::BuildSenderError;
Expand Down
Loading
Loading