Skip to content

Commit b3f0ca2

Browse files
committed
fix: jwe deserialization issues
Signed-off-by: Daniel Bluhm <[email protected]>
1 parent bc97b9f commit b3f0ca2

File tree

9 files changed

+188
-54
lines changed

9 files changed

+188
-54
lines changed

didcomm_messaging/crypto/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pydid import VerificationMethod
88

9-
from didcomm_messaging.jwe import JweEnvelope, from_b64url
9+
from didcomm_messaging.jwe import JweEnvelope
1010

1111

1212
class CryptoServiceError(Exception):
@@ -34,6 +34,7 @@ def multikey(self) -> str:
3434

3535
class SecretKey(ABC):
3636
"""Secret Key Type."""
37+
3738
@property
3839
@abstractmethod
3940
def kid(self) -> str:

didcomm_messaging/crypto/askar/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Askar backend for DIDComm Messaging."""
22
from collections import OrderedDict
33
import json
4-
from typing import Mapping, Optional, Sequence, Union
4+
from typing import Optional, Sequence, Union
55

66
from pydid import VerificationMethod
77
from didcomm_messaging.crypto import SecretsManager
@@ -406,8 +406,8 @@ async def ecdh_1pu_decrypt(
406406
recip.encrypted_key,
407407
cc_tag=wrapper.tag,
408408
)
409-
except AskarError:
410-
raise CryptoServiceError("Error decrypting content encryption key")
409+
except AskarError as err:
410+
raise CryptoServiceError("Error decrypting content encryption key") from err
411411

412412
try:
413413
plaintext = cek.aead_decrypt(

didcomm_messaging/crypto/basic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Basic Crypto Implementations."""
2+
3+
from typing import Optional
4+
from . import S, SecretsManager
5+
6+
7+
class InMemorySecretsManager(SecretsManager[S]):
8+
"""In Memory Secrets Manager."""
9+
10+
def __init__(self, secrets: Optional[dict] = None):
11+
"""Initialize the InMemorySecretsManager."""
12+
self.secrets = secrets or {}
13+
14+
async def get_secret_by_kid(self, kid: str) -> Optional[S]:
15+
"""Get a secret by its kid."""
16+
return self.secrets.get(kid)
17+
18+
async def add_secret(self, secret: S) -> None:
19+
"""Add a secret to the secrets manager."""
20+
self.secrets[secret.kid] = secret

didcomm_messaging/didcomm.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
"""Class DIDComm Messaging interface."""
22

33

4-
from typing import Generic, Literal, NamedTuple, Optional, Sequence, Union
4+
from dataclasses import dataclass
5+
from typing import Generic, Literal, Optional, Sequence, Union
56

67
from pydid import DIDUrl, VerificationMethod
7-
from didcomm_messaging.crypto import P, S, CryptoService, SecretKey, SecretsManager
8+
from didcomm_messaging.crypto import P, S, CryptoService, SecretsManager
89
from didcomm_messaging.jwe import JweEnvelope, from_b64url
910
from didcomm_messaging.resolver import Resolver
1011

1112

12-
class PackedMessageMetadata(NamedTuple):
13+
@dataclass
14+
class PackedMessageMetadata(Generic[S]):
1315
"""Unpack result."""
1416

1517
wrapper: JweEnvelope
1618
method: Literal["ECDH-ES", "ECDH-1PU"]
17-
recip_key: SecretKey
19+
recip_key: S
1820
sender_kid: Optional[str]
1921

2022

@@ -36,7 +38,7 @@ def __init__(
3638
self.crypto = crypto
3739
self.secrets = secrets
3840

39-
async def extract_packed_message_metadata(
41+
async def extract_packed_message_metadata( # noqa: C901
4042
self, enc_message: Union[str, bytes]
4143
) -> PackedMessageMetadata:
4244
"""Extract metadata from a packed DIDComm message."""
@@ -101,9 +103,7 @@ async def unpack(self, enc_message: Union[str, bytes]) -> bytes:
101103
enc_message, metadata.recip_key, sender_key
102104
)
103105

104-
async def recip_for_kid_or_default_for_did(
105-
self, kid_or_did: str
106-
) -> P:
106+
async def recip_for_kid_or_default_for_did(self, kid_or_did: str) -> P:
107107
"""Resolve a verification method for a kid or return default recip."""
108108
if "#" in kid_or_did:
109109
vm = await self.resolver.resolve_and_dereference_verification_method(
@@ -128,6 +128,31 @@ async def recip_for_kid_or_default_for_did(
128128

129129
return self.crypto.verification_method_to_public_key(vm)
130130

131+
async def default_sender_kid_for_did(self, did: str) -> str:
132+
"""Determine the kid of the default sender key for a DID."""
133+
if "#" in did:
134+
return did
135+
136+
doc = await self.resolver.resolve_and_parse(did)
137+
if not doc.key_agreement:
138+
raise DIDCommMessagingError(
139+
"No key agreement methods found; cannot determine recipient"
140+
)
141+
142+
default = doc.key_agreement[0]
143+
if isinstance(default, DIDUrl):
144+
vm = doc.dereference(default)
145+
if not isinstance(vm, VerificationMethod):
146+
raise DIDCommMessagingError(
147+
f"Expected verification method, found: {type(vm)}"
148+
)
149+
else:
150+
vm = default
151+
152+
if not vm.id.did:
153+
return vm.id.as_absolute(vm.controller)
154+
return vm.id
155+
131156
async def pack(
132157
self,
133158
message: bytes,
@@ -136,33 +161,15 @@ async def pack(
136161
**options,
137162
):
138163
"""Pack a DIDComm message."""
139-
recip_keys = [
140-
await self.recip_for_kid_or_default_for_did(kid) for kid in to
141-
]
142-
sender_key = await self.secrets.get_secret_by_kid(frm) if frm else None
164+
recip_keys = [await self.recip_for_kid_or_default_for_did(kid) for kid in to]
165+
sender_kid = await self.default_sender_kid_for_did(frm) if frm else None
166+
sender_key = (
167+
await self.secrets.get_secret_by_kid(sender_kid) if sender_kid else None
168+
)
169+
if frm and not sender_key:
170+
raise DIDCommMessagingError("No sender key found")
143171

144172
if sender_key:
145-
return await self.crypto.ecdh_1pu_encrypt(
146-
recip_keys, sender_key, message
147-
)
173+
return await self.crypto.ecdh_1pu_encrypt(recip_keys, sender_key, message)
148174
else:
149175
return await self.crypto.ecdh_es_encrypt(recip_keys, message)
150-
151-
152-
async def main():
153-
from aries_askar import Store
154-
from didcomm_messaging.askar import AskarCryptoService, AskarSecretsManager
155-
from didcomm_messaging.resolver.peer import Peer2, Peer4
156-
from didcomm_messaging.resolver import PrefixResolver
157-
158-
store = await Store.open("sqlite:///:memory:")
159-
kms = AskarSecretsManager(store)
160-
crypto = AskarCryptoService()
161-
didcomm = DIDCommMessaging(
162-
PrefixResolver(
163-
{
164-
"did:peer:2": Peer2(),
165-
"did:peer:4": Peer4()
166-
}
167-
), crypto, kms
168-
)

didcomm_messaging/jwe.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ def deserialize(cls, entry: Mapping[str, Any]) -> "JweRecipient":
4646
if "encrypted_key" not in entry:
4747
raise ValueError("Invalid JWE recipient: missing encrypted_key")
4848

49+
encrypyted_key = from_b64url(entry["encrypted_key"])
50+
4951
if "header" in entry:
5052
if not isinstance(entry["header"], dict):
5153
raise ValueError("Invalid JWE recipient: invalid header")
52-
return cls(encrypted_key=entry["encrypted_key"], header=entry["header"])
53-
return cls(**entry)
54+
return cls(encrypted_key=encrypyted_key, header=entry["header"])
55+
56+
return cls(encrypted_key=encrypyted_key)
5457

5558
def serialize(self) -> dict:
5659
"""Serialize the JWE recipient to a mapping."""
@@ -75,6 +78,7 @@ def __init__(
7578
self._with_flatten_recipients = with_flatten_recipients
7679
self._recipients: List[JweRecipient] = []
7780
self._protected: Optional[OrderedDict] = None
81+
self._unprotected: Optional[OrderedDict] = None
7882
self._protected_b64: Optional[bytes] = None
7983
self._ciphertext: Optional[bytes] = None
8084
self._iv: Optional[bytes] = None
@@ -137,9 +141,6 @@ def build(self):
137141
if not self._tag:
138142
raise ValueError("Missing tag for JWE")
139143

140-
if not self._aad:
141-
raise ValueError("Missing additional authenticated data for JWE")
142-
143144
if not self._recipients:
144145
raise ValueError("Missing recipients for JWE")
145146

@@ -158,6 +159,8 @@ def build(self):
158159
aad=self._aad,
159160
unprotected=self._unprotected,
160161
recipients=self._recipients,
162+
with_flatten_recipients=self._with_flatten_recipients,
163+
with_protected_recipients=self._with_protected_recipients,
161164
)
162165

163166

@@ -172,7 +175,7 @@ class JweEnvelope:
172175
iv: bytes
173176
tag: bytes
174177
aad: Optional[bytes] = None
175-
unprotected: dict = field(default_factory=OrderedDict)
178+
unprotected: Optional[dict] = field(default_factory=OrderedDict)
176179
with_protected_recipients: bool = False
177180
with_flatten_recipients: bool = True
178181

@@ -185,7 +188,7 @@ def from_json(cls, message: Union[bytes, str]) -> "JweEnvelope":
185188
raise ValueError("Invalid JWE: not JSON")
186189

187190
@classmethod
188-
def deserialize(cls, message: Mapping[str, Any]) -> "JweEnvelope":
191+
def deserialize(cls, message: Mapping[str, Any]) -> "JweEnvelope": # noqa: C901
189192
"""Deserialize a JWE envelope from a mapping."""
190193
# Basic validation
191194

@@ -248,7 +251,7 @@ def deserialize(cls, message: Mapping[str, Any]) -> "JweEnvelope":
248251
return cls._deserialize(message)
249252

250253
@classmethod
251-
def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope":
254+
def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope": # noqa: C901
252255
protected_b64 = parsed[IDENT_PROTECTED]
253256
try:
254257
protected: dict = json.loads(from_b64url(protected_b64))
@@ -281,7 +284,7 @@ def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope":
281284
encrypted_key = parsed[IDENT_ENC_KEY]
282285
header = parsed.get(IDENT_HEADER)
283286
else:
284-
raise ValueError("Invalid JWE: missing encrypted key")
287+
header = None
285288

286289
if recipients:
287290
if encrypted_key:
@@ -303,22 +306,27 @@ def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope":
303306
if recip.header and recip.header.keys() & all_h:
304307
raise ValueError("Invalid JWE: duplicate header")
305308

309+
ciphertext = from_b64url(parsed["ciphertext"])
310+
iv = from_b64url(parsed["iv"])
311+
tag = from_b64url(parsed["tag"])
312+
aad = from_b64url(parsed["aad"]) if "aad" in parsed else None
313+
306314
inst = cls(
307315
recipients=recipients,
308316
protected=protected,
309317
protected_b64=protected_b64,
310318
unprotected=unprotected,
311-
ciphertext=parsed["ciphertext"],
312-
iv=parsed["iv"],
313-
tag=parsed["tag"],
314-
aad=parsed.get("aad"),
319+
ciphertext=ciphertext,
320+
iv=iv,
321+
tag=tag,
322+
aad=aad,
315323
with_protected_recipients=protected_recipients,
316324
with_flatten_recipients=flat_recipients,
317325
)
318326

319327
return inst
320328

321-
def serialize(self) -> dict:
329+
def serialize(self) -> dict: # noqa: C901
322330
"""Serialize the JWE envelope to a mapping."""
323331
if self.protected_b64 is None:
324332
raise ValueError("Missing protected: use set_protected")
@@ -329,7 +337,7 @@ def serialize(self) -> dict:
329337
if self.tag is None:
330338
raise ValueError("Missing tag for JWE")
331339
env = OrderedDict()
332-
env["protected"] = self.protected_b64
340+
env["protected"] = self.protected_b64.decode("utf-8")
333341
if self.unprotected:
334342
env["unprotected"] = self.unprotected.copy()
335343
if not self.with_protected_recipients:

example.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Example of using DIDComm Messaging."""
2+
3+
from aries_askar import Key, KeyAlg
4+
from didcomm_messaging.crypto.askar import AskarCryptoService, AskarSecretKey
5+
from didcomm_messaging.crypto.basic import InMemorySecretsManager
6+
from didcomm_messaging.didcomm import DIDCommMessaging
7+
from didcomm_messaging.multiformats import multibase
8+
from didcomm_messaging.multiformats import multicodec
9+
from didcomm_messaging.resolver.peer import Peer2, Peer4
10+
from didcomm_messaging.resolver import PrefixResolver
11+
from did_peer_2 import KeySpec, generate, json
12+
13+
14+
async def main():
15+
"""An example of using DIDComm Messaging."""
16+
secrets = InMemorySecretsManager()
17+
crypto = AskarCryptoService()
18+
didcomm = DIDCommMessaging(
19+
PrefixResolver({"did:peer:2": Peer2(), "did:peer:4": Peer4()}), crypto, secrets
20+
)
21+
verkey = Key.generate(KeyAlg.ED25519)
22+
xkey = Key.generate(KeyAlg.X25519)
23+
did = generate(
24+
[
25+
KeySpec.verification(
26+
multibase.encode(
27+
multicodec.wrap("ed25519-pub", verkey.get_public_bytes()),
28+
"base58btc",
29+
)
30+
),
31+
KeySpec.key_agreement(
32+
multibase.encode(
33+
multicodec.wrap("x25519-pub", xkey.get_public_bytes()), "base58btc"
34+
)
35+
),
36+
],
37+
[],
38+
)
39+
await secrets.add_secret(AskarSecretKey(verkey, f"{did}#key-1"))
40+
await secrets.add_secret(AskarSecretKey(xkey, f"{did}#key-2"))
41+
print(did)
42+
packed = await didcomm.pack(b"hello world", [did], did)
43+
print(json.dumps(json.loads(packed), indent=2))
44+
unpacked = await didcomm.unpack(packed)
45+
print(unpacked)
46+
47+
48+
if __name__ == "__main__":
49+
import asyncio
50+
51+
asyncio.run(main())

pdm.lock

Lines changed: 14 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,5 @@ dev = [
5050
"pre-commit>=3.5.0",
5151
"black>=23.10.1",
5252
"ruff>=0.1.3",
53+
"pytest-asyncio>=0.21.1",
5354
]

0 commit comments

Comments
 (0)