Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion payjoin/src/core/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ pub(crate) use error::InternalSessionError;
pub use error::SessionError;
use serde::de::Deserializer;
use serde::{Deserialize, Serialize};
pub use session::{replay_event_log, SessionEvent, SessionHistory, SessionOutcome, SessionStatus};
pub use session::{
replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome,
SessionStatus,
};
use url::Url;
#[cfg(target_arch = "wasm32")]
use web_time::Duration;
Expand Down
198 changes: 135 additions & 63 deletions payjoin/src/core/receive/v2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,38 @@ use serde::{Deserialize, Serialize};
use super::{ReceiveSession, SessionContext};
use crate::error::{InternalReplayError, ReplayError};
use crate::output_substitution::OutputSubstitution;
use crate::persist::SessionPersister;
use crate::persist::{AsyncSessionPersister, SessionPersister};
use crate::receive::{InputPair, JsonReply, OriginalPayload, PsbtContext};
use crate::{ImplementationError, PjUri};

fn replay_events(
mut logs: impl Iterator<Item = SessionEvent>,
) -> Result<(ReceiveSession, Vec<SessionEvent>), ReplayError<ReceiveSession, SessionEvent>> {
let first_event = logs.next().ok_or(InternalReplayError::NoEvents)?;
let mut session_events = vec![first_event.clone()];
let mut receiver = match first_event {
SessionEvent::Created(context) => ReceiveSession::new(context),
_ => return Err(InternalReplayError::InvalidEvent(Box::new(first_event), None).into()),
};

for event in logs {
session_events.push(event.clone());
receiver = receiver.process_event(event)?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the previous code if process_event failed we close the session. Is this a regression?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks the close was move to the replay_events call side. Is that right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replay_events does no IO so that it can be reused in the async version of replay_event_log. So it just returns an error here and that's handled with a session close in replay_event_log/_async:

    let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
        Ok(r) => r,
        Err(e) => {
            persister.close().await.map_err(|ce| {
                InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
            })?;
            return Err(e);
        }
    };

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a second commit that adds test coverage for this scenario to confirm no regression

}
Ok((receiver, session_events))
}

fn construct_history(
session_events: Vec<SessionEvent>,
) -> Result<SessionHistory, ReplayError<ReceiveSession, SessionEvent>> {
let history = SessionHistory::new(session_events);
let ctx = history.session_context();
if ctx.expiration.elapsed() {
return Err(InternalReplayError::Expired(ctx.expiration).into());
}
Ok(history)
}

/// Replay a receiver event log to get the receiver in its current state [ReceiveSession]
/// and a session history [SessionHistory]
pub fn replay_event_log<P>(
Expand All @@ -17,35 +45,49 @@ where
P::SessionEvent: Into<SessionEvent> + Clone,
P::SessionEvent: From<SessionEvent>,
{
let mut logs = persister
let logs = persister
.load()
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

let first_event = logs.next().ok_or(InternalReplayError::NoEvents)?.into();
let mut session_events = vec![first_event.clone()];
let mut receiver = match first_event {
SessionEvent::Created(context) => ReceiveSession::new(context),
_ => return Err(InternalReplayError::InvalidEvent(Box::new(first_event), None).into()),
let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
Ok(r) => r,
Err(e) => {
persister.close().map_err(|ce| {
InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
})?;
return Err(e);
}
};
for event in logs {
session_events.push(event.clone().into());
receiver = receiver.process_event(event.into()).map_err(|e| {
if let Err(storage_err) = persister.close() {
return InternalReplayError::PersistenceFailure(ImplementationError::new(
storage_err,
))
.into();
}
e
})?;
}

let history = SessionHistory::new(session_events.clone());
let ctx = history.session_context();
if ctx.expiration.elapsed() {
return Err(InternalReplayError::Expired(ctx.expiration).into());
}
let history = construct_history(session_events)?;
Ok((receiver, history))
}

/// Async version of [replay_event_log]
pub async fn replay_event_log_async<P>(
persister: &P,
) -> Result<(ReceiveSession, SessionHistory), ReplayError<ReceiveSession, SessionEvent>>
where
P: AsyncSessionPersister,
P::SessionEvent: Into<SessionEvent> + Clone,
P::SessionEvent: From<SessionEvent>,
{
let logs = persister
.load()
.await
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
Ok(r) => r,
Err(e) => {
persister.close().await.map_err(|ce| {
InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
})?;
return Err(e);
}
};

let history = construct_history(session_events)?;
Ok((receiver, history))
}

Expand Down Expand Up @@ -183,7 +225,7 @@ mod tests {
use payjoin_test_utils::{BoxError, EXAMPLE_URL};

use super::*;
use crate::persist::test_utils::InMemoryTestPersister;
use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister};
use crate::persist::NoopSessionPersister;
use crate::receive::tests::original_from_test_vector;
use crate::receive::v2::test::{mock_err, SHARED_CONTEXT};
Expand Down Expand Up @@ -268,39 +310,55 @@ mod tests {
}
}

#[derive(Clone)]
struct SessionHistoryExpectedOutcome {
fallback_tx: Option<bitcoin::Transaction>,
expected_status: SessionStatus,
}

#[derive(Clone)]
struct SessionHistoryTest {
events: Vec<SessionEvent>,
expected_session_history: SessionHistoryExpectedOutcome,
expected_receiver_state: ReceiveSession,
}

fn run_session_history_test(test: SessionHistoryTest) -> Result<(), BoxError> {
let persister = InMemoryTestPersister::<SessionEvent>::default();
for event in test.events.clone() {
persister.save_event(event)?;
}

let (receiver, session_history) = replay_event_log(&persister)?;
fn verify_session_result(
session_result: Result<
(ReceiveSession, SessionHistory),
crate::error::ReplayError<ReceiveSession, SessionEvent>,
>,
test: &SessionHistoryTest,
) {
let (receiver, session_history) = session_result.expect("replay should succeed");
assert_eq!(receiver, test.expected_receiver_state);
assert_eq!(session_history.fallback_tx(), test.expected_session_history.fallback_tx);
assert_eq!(session_history.status(), test.expected_session_history.expected_status);
let expected_reply_key = test.events.iter().find_map(|event| match event {
SessionEvent::RetrievedOriginalPayload { reply_key, .. } => reply_key.clone(),
_ => None,
});

assert_eq!(session_history.session_context().reply_key, expected_reply_key);
}

Ok(())
fn run_session_history_test(test: &SessionHistoryTest) {
let persister = InMemoryTestPersister::<SessionEvent>::default();
for event in test.events.clone() {
persister.save_event(event).expect("In memory persister shouldn't fail");
}
verify_session_result(replay_event_log(&persister), test);
}

#[test]
fn test_replaying_session_creation() -> Result<(), BoxError> {
async fn run_session_history_test_async(test: &SessionHistoryTest) {
let persister = InMemoryAsyncTestPersister::<SessionEvent>::default();
for event in test.events.clone() {
persister.save_event(event).await.expect("In memory persister shouldn't fail");
}
verify_session_result(replay_event_log_async(&persister).await, test);
}

#[tokio::test]
async fn test_replaying_session_creation() {
let session_context = SHARED_CONTEXT.clone();
let test = SessionHistoryTest {
events: vec![SessionEvent::Created(session_context.clone())],
Expand All @@ -313,25 +371,32 @@ mod tests {
session_context,
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_replaying_session_creation_with_expired_session() -> Result<(), BoxError> {
#[tokio::test]
async fn test_replaying_session_creation_with_expired_session() {
let expiration = (SystemTime::now() - Duration::from_secs(1)).try_into().unwrap();
let session_context = SessionContext { expiration, ..SHARED_CONTEXT.clone() };
let persister = InMemoryTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::Created(session_context))?;

let persister = InMemoryTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::Created(session_context.clone()));
let err = replay_event_log(&persister).expect_err("session should be expired");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::Expired(expiration).into();
assert_eq!(err.to_string(), expected_err.to_string());
Ok(())

let persister = InMemoryAsyncTestPersister::<SessionEvent>::default();
persister.save_event(SessionEvent::Created(session_context)).await;
let err = replay_event_log_async(&persister).await.expect_err("session should be expired");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::Expired(expiration).into();
assert_eq!(err.to_string(), expected_err.to_string());
}

#[test]
fn test_replaying_unchecked_proposal() -> Result<(), BoxError> {
#[tokio::test]
async fn test_replaying_unchecked_proposal() {
let session_context = SHARED_CONTEXT.clone();
let original = original_from_test_vector();
let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1);
Expand All @@ -353,11 +418,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_replaying_unchecked_proposal_with_reply_key() -> Result<(), BoxError> {
#[tokio::test]
async fn test_replaying_unchecked_proposal_with_reply_key() {
let session_context = SHARED_CONTEXT.clone();
let original = original_from_test_vector();
let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1);
Expand All @@ -379,11 +445,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn getting_fallback_tx() -> Result<(), BoxError> {
#[tokio::test]
async fn getting_fallback_tx() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -413,11 +480,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_contributed_inputs() -> Result<(), BoxError> {
#[tokio::test]
async fn test_contributed_inputs() {
let persister = InMemoryTestPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -486,11 +554,12 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_payjoin_proposal() -> Result<(), BoxError> {
#[tokio::test]
async fn test_payjoin_proposal() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -561,11 +630,12 @@ mod tests {
},
expected_receiver_state: ReceiveSession::Closed(SessionOutcome::Success(vec![])),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_session_fatal_error() -> Result<(), BoxError> {
#[tokio::test]
async fn test_session_fatal_error() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -596,11 +666,12 @@ mod tests {
},
expected_receiver_state: ReceiveSession::Closed(SessionOutcome::Failure),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
fn test_session_transient_error() -> Result<(), BoxError> {
#[tokio::test]
async fn test_session_transient_error() {
let persister = NoopSessionPersister::<SessionEvent>::default();
let session_context = SHARED_CONTEXT.clone();
let mut events = vec![];
Expand Down Expand Up @@ -632,7 +703,8 @@ mod tests {
session_context: SessionContext { reply_key, ..session_context },
}),
};
run_session_history_test(test)
run_session_history_test(&test);
run_session_history_test_async(&test).await;
}

#[test]
Expand Down
5 changes: 4 additions & 1 deletion payjoin/src/core/send/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ pub use error::{CreateRequestError, EncapsulationError};
use error::{InternalCreateRequestError, InternalEncapsulationError};
use ohttp::ClientResponse;
use serde::{Deserialize, Serialize};
pub use session::{replay_event_log, SessionEvent, SessionHistory, SessionOutcome, SessionStatus};
pub use session::{
replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome,
SessionStatus,
};
use url::Url;

use super::error::BuildSenderError;
Expand Down
Loading
Loading