Skip to content

Commit 1f0e94a

Browse files
committed
even more tidying up
1 parent 923b805 commit 1f0e94a

File tree

6 files changed

+113
-189
lines changed

6 files changed

+113
-189
lines changed

src/handlers/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ pub trait ComputeHandler {
2323
/// A generic handler for DKN tasks.
2424
///
2525
/// Returns a `MessageAcceptance` value that tells the P2P client to accept the incoming message.
26+
///
27+
/// The handler has mutable reference to the compute node, and therefore can respond within the handler itself in any way it would like.
2628
async fn handle_compute(
2729
node: &mut DriaComputeNode,
2830
message: DKNMessage,

src/handlers/pingpong.rs

Lines changed: 4 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::{
33
DriaComputeNode,
44
};
55
use async_trait::async_trait;
6-
use eyre::Result;
6+
use eyre::{Context, Result};
77
use libp2p::gossipsub::MessageAcceptance;
88
use ollama_workflows::{Model, ModelProvider};
99
use serde::{Deserialize, Serialize};
@@ -34,7 +34,9 @@ impl ComputeHandler for PingpongHandler {
3434
node: &mut DriaComputeNode,
3535
message: DKNMessage,
3636
) -> Result<MessageAcceptance> {
37-
let pingpong = message.parse_payload::<PingpongPayload>(true)?;
37+
let pingpong = message
38+
.parse_payload::<PingpongPayload>(true)
39+
.wrap_err("Could not parse ping request")?;
3840

3941
// check deadline
4042
let current_time = get_current_time_nanos();
@@ -68,77 +70,3 @@ impl ComputeHandler for PingpongHandler {
6870
Ok(MessageAcceptance::Accept)
6971
}
7072
}
71-
72-
#[cfg(test)]
73-
mod tests {
74-
use crate::{
75-
utils::{
76-
crypto::{sha256hash, to_address},
77-
filter::TaskFilter,
78-
DKNMessage,
79-
},
80-
DriaComputeNodeConfig,
81-
};
82-
use fastbloom_rs::{FilterBuilder, Membership};
83-
use libsecp256k1::{recover, Message, PublicKey};
84-
85-
use super::PingpongPayload;
86-
87-
#[test]
88-
fn test_pingpong_payload() {
89-
let pk = PublicKey::parse_compressed(&hex_literal::hex!(
90-
"0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658"
91-
))
92-
.expect("Should parse public key");
93-
let message = DKNMessage {
94-
payload: "Y2RmODcyNDlhY2U3YzQ2MDIzYzNkMzBhOTc4ZWY3NjViMWVhZDlmNWJhMDUyY2MxMmY0NzIzMjQyYjc0YmYyODFjMDA1MTdmMGYzM2VkNTgzMzk1YWUzMTY1ODQ3NWQyNDRlODAxYzAxZDE5MjYwMDM1MTRkNzEwMThmYTJkNjEwMXsidXVpZCI6ICI4MWE2M2EzNC05NmM2LTRlNWEtOTliNS02YjI3NGQ5ZGUxNzUiLCAiZGVhZGxpbmUiOiAxNzE0MTI4NzkyfQ==".to_string(),
95-
topic: "pingpong".to_string(),
96-
version: "0.0.0".to_string(),
97-
timestamp: 1714129073557846272,
98-
};
99-
100-
assert!(message.is_signed(&pk).expect("Should check signature"));
101-
102-
let obj = message
103-
.parse_payload::<PingpongPayload>(true)
104-
.expect("Should parse payload");
105-
assert_eq!(obj.uuid, "81a63a34-96c6-4e5a-99b5-6b274d9de175");
106-
assert_eq!(obj.deadline, 1714128792);
107-
}
108-
109-
/// This test demonstrates the process of pingpong & task assignment.
110-
///
111-
/// A heart-beat message is sent over the network by Admin Node, and compute node responds with a signature.
112-
#[test]
113-
fn test_pingpong_and_task_assignment() {
114-
let config = DriaComputeNodeConfig::default();
115-
116-
// a pingpong message is signed and sent to Admin Node over the p2p network
117-
let pingpong_message = Message::parse(&sha256hash(b"sign-me"));
118-
let (pingpong_signature, pingpong_recid) =
119-
libsecp256k1::sign(&pingpong_message, &config.secret_key);
120-
121-
// admin recovers the address from the signature
122-
let recovered_public_key = recover(&pingpong_message, &pingpong_signature, &pingpong_recid)
123-
.expect("Could not recover");
124-
assert_eq!(
125-
config.public_key, recovered_public_key,
126-
"Public key mismatch"
127-
);
128-
let address = to_address(&recovered_public_key);
129-
assert_eq!(address, config.address, "Address mismatch");
130-
131-
// admin node assigns the task to the compute node via Bloom Filter
132-
let mut bloom = FilterBuilder::new(100, 0.01).build_bloom_filter();
133-
bloom.add(&address);
134-
let filter_payload = TaskFilter::from(bloom);
135-
136-
// compute node receives the filter and checks if it is tasked
137-
assert!(
138-
filter_payload
139-
.contains(&config.address)
140-
.expect("Should check filter"),
141-
"Node should be tasked"
142-
);
143-
}
144-
}

src/handlers/workflow.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use async_trait::async_trait;
22
use eyre::{eyre, Context, Result};
33
use libp2p::gossipsub::MessageAcceptance;
4+
use libsecp256k1::PublicKey;
45
use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
56
use serde::Deserialize;
67

@@ -37,7 +38,7 @@ impl ComputeHandler for WorkflowHandler {
3738
let config = &node.config;
3839
let task = message
3940
.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)
40-
.wrap_err("Could not parse error")?;
41+
.wrap_err("Could not parse workflow task")?;
4142

4243
// check if deadline is past or not
4344
let current_time = get_current_time_nanos();
@@ -97,8 +98,9 @@ impl ComputeHandler for WorkflowHandler {
9798
match exec_result {
9899
Ok(result) => {
99100
// obtain public key from the payload
100-
let task_public_key =
101+
let task_public_key_bytes =
101102
hex::decode(&task.public_key).wrap_err("Could not decode public key")?;
103+
let task_public_key = PublicKey::parse_slice(&task_public_key_bytes, None)?;
102104

103105
// prepare signed and encrypted payload
104106
let payload = TaskResponsePayload::new(

src/payloads/response.rs

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use crate::utils::crypto::sha256hash;
1+
use crate::utils::crypto::{encrypt_bytes, sha256hash, sign_bytes_recoverable};
22
use eyre::Result;
3-
use libsecp256k1::SecretKey;
3+
use libsecp256k1::{PublicKey, SecretKey};
44
use serde::{Deserialize, Serialize};
55

66
/// A computation task is the task of computing a result from a given input. The result is encrypted with the public key of the requester.
@@ -27,28 +27,68 @@ impl TaskResponsePayload {
2727
pub fn new(
2828
result: impl AsRef<[u8]>,
2929
task_id: &str,
30-
encrypting_public_key: &[u8],
30+
encrypting_public_key: &PublicKey,
3131
signing_secret_key: &SecretKey,
3232
) -> Result<Self> {
3333
// create the message `task_id || payload`
3434
let mut preimage = Vec::new();
3535
preimage.extend_from_slice(task_id.as_ref());
3636
preimage.extend_from_slice(result.as_ref());
3737

38-
// sign the message
39-
// TODO: use `sign_recoverable` here instead?
40-
let digest = libsecp256k1::Message::parse(&sha256hash(preimage));
41-
let (signature, recid) = libsecp256k1::sign(&digest, signing_secret_key);
42-
let signature: [u8; 64] = signature.serialize();
43-
let recid: [u8; 1] = [recid.serialize()];
44-
45-
// encrypt payload itself
46-
let ciphertext = ecies::encrypt(encrypting_public_key, result.as_ref())?;
38+
let signature = sign_bytes_recoverable(&sha256hash(preimage), signing_secret_key);
39+
let ciphertext = encrypt_bytes(result, encrypting_public_key)?;
4740

4841
Ok(TaskResponsePayload {
4942
ciphertext: hex::encode(ciphertext),
50-
signature: format!("{}{}", hex::encode(signature), hex::encode(recid)),
43+
signature,
5144
task_id: task_id.to_string(),
5245
})
5346
}
5447
}
48+
49+
#[cfg(test)]
50+
mod tests {
51+
use super::*;
52+
use ecies::decrypt;
53+
use libsecp256k1::{recover, verify, Message, PublicKey, RecoveryId, Signature};
54+
use rand::thread_rng;
55+
56+
#[test]
57+
fn test_task_response_payload() {
58+
// this is the result that we are "sending"
59+
const RESULT: &[u8; 44] = b"hey im an LLM and I came up with this output";
60+
61+
// the signer will sign the payload, and it will be verified
62+
let signer_sk = SecretKey::random(&mut thread_rng());
63+
let signer_pk = PublicKey::from_secret_key(&signer_sk);
64+
65+
// the payload will be encrypted with this key
66+
let task_sk = SecretKey::random(&mut thread_rng());
67+
let task_pk = PublicKey::from_secret_key(&task_sk);
68+
let task_id = uuid::Uuid::new_v4().to_string();
69+
70+
// creates a signed and encrypted payload
71+
let payload = TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk)
72+
.expect("Should create payload");
73+
74+
// decrypt result and compare it to plaintext
75+
let ciphertext_bytes = hex::decode(payload.ciphertext).unwrap();
76+
let result = decrypt(&task_sk.serialize(), &ciphertext_bytes).expect("Could not decrypt");
77+
assert_eq!(result, RESULT, "Result mismatch");
78+
79+
// verify signature
80+
let signature_bytes = hex::decode(payload.signature).expect("Should decode");
81+
let signature = Signature::parse_standard_slice(&signature_bytes[..64]).unwrap();
82+
let recid = RecoveryId::parse(signature_bytes[64]).unwrap();
83+
let mut preimage = vec![];
84+
preimage.extend_from_slice(task_id.as_bytes());
85+
preimage.extend_from_slice(&result);
86+
let message = Message::parse(&sha256hash(preimage));
87+
assert!(verify(&message, &signature, &signer_pk), "Could not verify");
88+
89+
// recover verifying key (public key) from signature
90+
let recovered_public_key =
91+
recover(&message, &signature, &recid).expect("Could not recover");
92+
assert_eq!(signer_pk, recovered_public_key, "Public key mismatch");
93+
}
94+
}

src/utils/crypto.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use ecies::PublicKey;
2+
use eyre::{Context, Result};
23
use libp2p_identity::Keypair;
3-
use libsecp256k1::{sign, Message, SecretKey};
4+
use libsecp256k1::{Message, SecretKey};
45
use sha2::{Digest, Sha256};
56
use sha3::Keccak256;
67

@@ -33,7 +34,7 @@ pub fn to_address(public_key: &PublicKey) -> [u8; 20] {
3334
#[inline]
3435
pub fn sign_bytes_recoverable(message: &[u8; 32], secret_key: &SecretKey) -> String {
3536
let message = Message::parse(message);
36-
let (signature, recid) = sign(&message, secret_key);
37+
let (signature, recid) = libsecp256k1::sign(&message, secret_key);
3738

3839
format!(
3940
"{}{}",
@@ -42,6 +43,15 @@ pub fn sign_bytes_recoverable(message: &[u8; 32], secret_key: &SecretKey) -> Str
4243
)
4344
}
4445

46+
/// Shorthand to encrypt bytes with a given public key.
47+
/// Returns hexadecimal encoded ciphertext.
48+
#[inline]
49+
pub fn encrypt_bytes(data: impl AsRef<[u8]>, public_key: &PublicKey) -> Result<String> {
50+
ecies::encrypt(public_key.serialize().as_slice(), data.as_ref())
51+
.wrap_err("could not encrypt data")
52+
.map(hex::encode)
53+
}
54+
4555
/// Converts a `libsecp256k1::SecretKey` to a `libp2p_identity::secp256k1::Keypair`.
4656
/// To do this, we serialize the secret key and create a new keypair from it.
4757
#[inline]

0 commit comments

Comments
 (0)