Skip to content

Commit 1942884

Browse files
committed
fix(hip-3-pusher): KMS support, various enhancements
1 parent 8979a0d commit 1942884

File tree

2 files changed

+59
-36
lines changed

2 files changed

+59
-36
lines changed
Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
import boto3
2-
from asn1crypto import core
2+
from cryptography.hazmat.primitives import serialization
3+
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
34
from eth_account.messages import encode_typed_data, _hash_eip191_message
5+
from eth_keys.backends.native.ecdsa import N as SECP256K1_N
46
from eth_keys.datatypes import Signature
57
from eth_utils import keccak, to_hex
68
from hyperliquid.exchange import Exchange
79
from hyperliquid.utils.constants import TESTNET_API_URL, MAINNET_API_URL
810
from hyperliquid.utils.signing import get_timestamp_ms, action_hash, construct_phantom_agent, l1_payload
911
from loguru import logger
1012

13+
SECP256K1_N_HALF = SECP256K1_N // 2
14+
1115

1216
class KMSSigner:
13-
def __init__(self, key_id, aws_region_name, use_testnet):
17+
def __init__(self, config):
18+
use_testnet = config["hyperliquid"]["use_testnet"]
1419
url = TESTNET_API_URL if use_testnet else MAINNET_API_URL
1520
self.oracle_publisher_exchange: Exchange = Exchange(wallet=None, base_url=url)
21+
self.client = self._init_client(config)
1622

17-
self.key_id = key_id
18-
self.client = boto3.client("kms", region_name=aws_region_name)
1923
# Fetch public key once so we can derive address and check recovery id
20-
pub_der = self.client.get_public_key(KeyId=key_id)["PublicKey"]
21-
22-
from cryptography.hazmat.primitives import serialization
23-
pub = serialization.load_der_public_key(pub_der)
24+
key_path = config["kms"]["key_path"]
25+
self.key_id = open(key_path, "r").read().strip()
26+
self.pubkey_der = self.client.get_public_key(KeyId=self.key_id)["PublicKey"]
27+
# Construct eth address to log
28+
pub = serialization.load_der_public_key(self.pubkey_der)
2429
numbers = pub.public_numbers()
2530
x = numbers.x.to_bytes(32, "big")
2631
y = numbers.y.to_bytes(32, "big")
@@ -29,6 +34,22 @@ def __init__(self, key_id, aws_region_name, use_testnet):
2934
self.address = "0x" + keccak(uncompressed[1:])[-20:].hex()
3035
logger.info("KMSSigner address: {}", self.address)
3136

37+
def _init_client(self, config):
38+
aws_region_name = config["kms"]["aws_region_name"]
39+
access_key_id_path = config["kms"]["access_key_id_path"]
40+
access_key_id = open(access_key_id_path, "r").read().strip()
41+
secret_access_key_path = config["kms"]["secret_access_key_path"]
42+
secret_access_key = open(secret_access_key_path, "r").read().strip()
43+
44+
return boto3.client(
45+
"kms",
46+
region_name=aws_region_name,
47+
aws_access_key_id=access_key_id,
48+
aws_secret_access_key=secret_access_key,
49+
# can specify an endpoint for e.g. LocalStack
50+
# endpoint_url="http://localhost:4566"
51+
)
52+
3253
def set_oracle(self, dex, oracle_pxs, all_mark_pxs, external_perp_pxs):
3354
timestamp = get_timestamp_ms()
3455
oracle_pxs_wire = sorted(list(oracle_pxs.items()))
@@ -60,34 +81,39 @@ def sign_l1_action(self, action, nonce, is_mainnet):
6081
data = l1_payload(phantom_agent)
6182
structured_data = encode_typed_data(full_message=data)
6283
message_hash = _hash_eip191_message(structured_data)
63-
signed = self.sign_message(message_hash)
64-
return {"r": to_hex(signed["r"]), "s": to_hex(signed["s"]), "v": signed["v"]}
84+
return self.sign_message(message_hash)
6585

66-
def sign_message(self, message_hash: bytes):
86+
def sign_message(self, message_hash: bytes) -> dict:
87+
# Send message hash to KMS for signing
6788
resp = self.client.sign(
6889
KeyId=self.key_id,
6990
Message=message_hash,
7091
MessageType="DIGEST",
7192
SigningAlgorithm="ECDSA_SHA_256", # required for secp256k1
7293
)
73-
der_sig = resp["Signature"]
74-
75-
seq = core.Sequence.load(der_sig)
76-
r = int(seq[0].native)
77-
s = int(seq[1].native)
78-
79-
for recovery_id in (0, 1):
80-
candidate = Signature(vrs=(recovery_id, r, s))
81-
pubkey = candidate.recover_public_key_from_msg_hash(message_hash)
82-
if pubkey.to_bytes() == self.public_key_bytes:
83-
v = recovery_id + 27
84-
break
85-
else:
86-
raise ValueError("Failed to determine recovery id")
87-
88-
return {
89-
"r": r,
90-
"s": s,
91-
"v": v,
92-
"signature": Signature(vrs=(v, r, s)).to_bytes().hex(),
93-
}
94+
kms_signature = resp["Signature"]
95+
# Decode the KMS DER signature -> (r, s)
96+
r, s = decode_dss_signature(kms_signature)
97+
# Ethereum requires low-s form
98+
if s > SECP256K1_N_HALF:
99+
s = SECP256K1_N - s
100+
# Parse KMS public key into uncompressed secp256k1 bytes
101+
# TODO: Pull this into init
102+
pubkey = serialization.load_der_public_key(self.pubkey_der)
103+
pubkey_bytes = pubkey.public_bytes(
104+
serialization.Encoding.X962,
105+
serialization.PublicFormat.UncompressedPoint,
106+
)
107+
# Strip leading 0x04 (uncompressed point indicator)
108+
raw_pubkey_bytes = pubkey_bytes[1:]
109+
# Try both recovery ids
110+
for v in (0, 1):
111+
sig_obj = Signature(vrs=(v, r, s))
112+
recovered_pub = sig_obj.recover_public_key_from_msg_hash(message_hash)
113+
if recovered_pub.to_bytes() == raw_pubkey_bytes:
114+
return {
115+
"r": to_hex(r),
116+
"s": to_hex(s),
117+
"v": v + 27,
118+
}
119+
raise ValueError("Could not recover public key; signature mismatch")

apps/hip-3-pusher/src/publisher.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,11 @@ def __init__(self, config: dict, price_state: PriceState, metrics: Metrics):
2727
if config["kms"]["enable_kms"]:
2828
self.enable_kms = True
2929
oracle_account = None
30-
kms_key_path = config["kms"]["key_path"]
31-
kms_key_id = open(kms_key_path, "r").read().strip()
32-
self.kms_signer = KMSSigner(kms_key_id, config["kms"]["aws_region_name"], self.use_testnet)
30+
self.kms_signer = KMSSigner(config)
3331
else:
3432
oracle_pusher_key_path = config["hyperliquid"]["oracle_pusher_key_path"]
3533
oracle_pusher_key = open(oracle_pusher_key_path, "r").read().strip()
3634
oracle_account: LocalAccount = Account.from_key(oracle_pusher_key)
37-
del oracle_pusher_key
3835
logger.info("oracle pusher local pubkey: {}", oracle_account.address)
3936

4037
url = TESTNET_API_URL if self.use_testnet else MAINNET_API_URL

0 commit comments

Comments
 (0)