Skip to content

Commit 0a74b3e

Browse files
committed
refactor(sender_allocation_task): add receipt validation
1 parent b4a6363 commit 0a74b3e

File tree

1 file changed

+148
-20
lines changed

1 file changed

+148
-20
lines changed

crates/tap-agent/src/agent/sender_allocation_task.rs

Lines changed: 148 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ pub struct SenderAllocationTask<T: NetworkVersion> {
3737
_phantom: PhantomData<T>,
3838
}
3939

40-
/// Simple state structure for the task
40+
/// Simple state structure for the task (will be enhanced incrementally)
4141
struct TaskState {
4242
/// Sum of all receipt fees for the current allocation
4343
unaggregated_fees: UnaggregatedReceipts,
@@ -115,27 +115,39 @@ impl<T: NetworkVersion> SenderAllocationTask<T> {
115115
Ok(())
116116
}
117117

118-
/// Handle new receipt - simplified version
118+
/// Handle new receipt - with basic validation
119119
async fn handle_new_receipt(
120120
state: &mut TaskState,
121121
notification: NewReceiptNotification,
122122
) -> anyhow::Result<()> {
123-
// For now, just accept all receipts as valid
124-
// In the full implementation, this would include validation
125-
126-
let value = match notification {
127-
NewReceiptNotification::V1(ref n) => n.value,
128-
NewReceiptNotification::V2(ref n) => n.value,
123+
let (id, value, timestamp_ns) = match notification {
124+
NewReceiptNotification::V1(ref n) => (n.id, n.value, n.timestamp_ns),
125+
NewReceiptNotification::V2(ref n) => (n.id, n.value, n.timestamp_ns),
129126
};
130127

131-
let timestamp_ns = match notification {
132-
NewReceiptNotification::V1(ref n) => n.timestamp_ns,
133-
NewReceiptNotification::V2(ref n) => n.timestamp_ns,
134-
};
128+
// Basic receipt ID validation - reject already processed receipts
129+
if id <= state.unaggregated_fees.last_id {
130+
tracing::warn!(
131+
allocation_id = ?state.allocation_id,
132+
receipt_id = id,
133+
last_processed_id = state.unaggregated_fees.last_id,
134+
"Rejecting receipt with ID <= last processed ID"
135+
);
136+
return Ok(()); // Silently ignore duplicate/old receipts
137+
}
135138

136-
// Update local state
139+
// Update local state with new receipt
137140
state.unaggregated_fees.value += value;
138141
state.unaggregated_fees.counter += 1;
142+
state.unaggregated_fees.last_id = id;
143+
144+
tracing::debug!(
145+
allocation_id = ?state.allocation_id,
146+
receipt_id = id,
147+
value = value,
148+
new_total = state.unaggregated_fees.value,
149+
"Processed new receipt"
150+
);
139151

140152
// Notify parent
141153
state
@@ -149,14 +161,28 @@ impl<T: NetworkVersion> SenderAllocationTask<T> {
149161
Ok(())
150162
}
151163

152-
/// Handle RAV request - simplified version
164+
/// Handle RAV request - enhanced but still simplified version
153165
async fn handle_rav_request(state: &mut TaskState) -> anyhow::Result<()> {
154-
let _start_time = Instant::now();
166+
let start_time = Instant::now();
167+
168+
// Check if there are any receipts to aggregate
169+
if state.unaggregated_fees.value == 0 {
170+
tracing::debug!(
171+
allocation_id = ?state.allocation_id,
172+
"No receipts to aggregate, skipping RAV request"
173+
);
174+
return Ok(());
175+
}
155176

156-
// For now, simulate a successful RAV request
157-
// In the full implementation, this would make actual gRPC calls
177+
tracing::info!(
178+
allocation_id = ?state.allocation_id,
179+
receipt_count = state.unaggregated_fees.counter,
180+
total_value = state.unaggregated_fees.value,
181+
"Creating RAV for aggregated receipts"
182+
);
158183

159-
// Create a dummy RAV info
184+
// TODO: Replace with real TAP manager integration
185+
// For now, simulate a successful RAV request
160186
let rav_info = RavInformation {
161187
allocation_id: match state.allocation_id {
162188
AllocationId::Legacy(id) => id.into_inner(),
@@ -168,15 +194,26 @@ impl<T: NetworkVersion> SenderAllocationTask<T> {
168194
value_aggregate: state.unaggregated_fees.value,
169195
};
170196

197+
// Store the fees we're about to clear for the response
198+
let fees_to_clear = state.unaggregated_fees;
199+
171200
// Reset local fees since they're now covered by RAV
172201
state.unaggregated_fees = UnaggregatedReceipts::default();
173202

203+
let elapsed = start_time.elapsed();
204+
tracing::info!(
205+
allocation_id = ?state.allocation_id,
206+
rav_value = rav_info.value_aggregate,
207+
duration_ms = elapsed.as_millis(),
208+
"RAV creation completed successfully"
209+
);
210+
174211
// Notify parent of successful RAV
175212
state
176213
.sender_account_handle
177214
.cast(SenderAccountMessage::UpdateReceiptFees(
178215
state.allocation_id,
179-
ReceiptFees::RavRequestResponse(state.unaggregated_fees, Ok(Some(rav_info))),
216+
ReceiptFees::RavRequestResponse(fees_to_clear, Ok(Some(rav_info))),
180217
))
181218
.await?;
182219

@@ -238,6 +275,97 @@ mod tests {
238275
.unwrap();
239276

240277
// Check it's the right type of message
241-
matches!(parent_message, SenderAccountMessage::UpdateReceiptFees(..));
278+
assert!(matches!(
279+
parent_message,
280+
SenderAccountMessage::UpdateReceiptFees(..)
281+
));
282+
}
283+
284+
#[tokio::test]
285+
async fn test_receipt_id_validation() {
286+
let lifecycle = LifecycleManager::new();
287+
288+
// Create a dummy parent handle for testing
289+
let (parent_tx, mut parent_rx) = mpsc::channel(10);
290+
let parent_handle = TaskHandle::new_for_test(
291+
parent_tx,
292+
Some("test_parent".to_string()),
293+
std::sync::Arc::new(lifecycle.clone()),
294+
);
295+
296+
let allocation_id =
297+
AllocationId::Legacy(thegraph_core::AllocationId::new([1u8; 20].into()));
298+
299+
let task_handle = SenderAllocationTask::<Legacy>::spawn_simple(
300+
&lifecycle,
301+
Some("test_allocation".to_string()),
302+
allocation_id,
303+
parent_handle,
304+
)
305+
.await
306+
.unwrap();
307+
308+
// Send first receipt (should be accepted)
309+
let notification1 = NewReceiptNotification::V1(
310+
super::super::sender_accounts_manager::NewReceiptNotificationV1 {
311+
id: 100,
312+
allocation_id: thegraph_core::AllocationId::new([1u8; 20].into()).into_inner(),
313+
signer_address: thegraph_core::alloy::primitives::Address::ZERO,
314+
timestamp_ns: 1000,
315+
value: 100,
316+
},
317+
);
318+
319+
task_handle
320+
.cast(SenderAllocationMessage::NewReceipt(notification1))
321+
.await
322+
.unwrap();
323+
324+
// Receive first update
325+
let _first_message = parent_rx.recv().await.unwrap();
326+
327+
// Send second receipt with same ID (should be rejected silently)
328+
let notification2 = NewReceiptNotification::V1(
329+
super::super::sender_accounts_manager::NewReceiptNotificationV1 {
330+
id: 100, // Same ID - should be rejected
331+
allocation_id: thegraph_core::AllocationId::new([1u8; 20].into()).into_inner(),
332+
signer_address: thegraph_core::alloy::primitives::Address::ZERO,
333+
timestamp_ns: 2000,
334+
value: 200,
335+
},
336+
);
337+
338+
task_handle
339+
.cast(SenderAllocationMessage::NewReceipt(notification2))
340+
.await
341+
.unwrap();
342+
343+
// Send third receipt with higher ID (should be accepted)
344+
let notification3 = NewReceiptNotification::V1(
345+
super::super::sender_accounts_manager::NewReceiptNotificationV1 {
346+
id: 101, // Higher ID - should be accepted
347+
allocation_id: thegraph_core::AllocationId::new([1u8; 20].into()).into_inner(),
348+
signer_address: thegraph_core::alloy::primitives::Address::ZERO,
349+
timestamp_ns: 3000,
350+
value: 300,
351+
},
352+
);
353+
354+
task_handle
355+
.cast(SenderAllocationMessage::NewReceipt(notification3))
356+
.await
357+
.unwrap();
358+
359+
// Should only receive one more message (for the third receipt)
360+
let second_message =
361+
tokio::time::timeout(std::time::Duration::from_millis(100), parent_rx.recv())
362+
.await
363+
.unwrap()
364+
.unwrap();
365+
366+
assert!(matches!(
367+
second_message,
368+
SenderAccountMessage::UpdateReceiptFees(..)
369+
));
242370
}
243371
}

0 commit comments

Comments
 (0)