|
| 1 | +import boto3 |
| 2 | +from asn1crypto import core |
| 3 | +from eth_account.messages import encode_typed_data, _hash_eip191_message |
| 4 | +from eth_keys.datatypes import Signature |
| 5 | +from eth_utils import keccak, to_hex |
| 6 | +from hyperliquid.exchange import Exchange |
| 7 | +from hyperliquid.utils.constants import TESTNET_API_URL, MAINNET_API_URL |
| 8 | +from hyperliquid.utils.signing import get_timestamp_ms, action_hash, construct_phantom_agent, l1_payload |
| 9 | +from loguru import logger |
| 10 | + |
| 11 | + |
| 12 | +class KMSSigner: |
| 13 | + def __init__(self, key_id, aws_region_name, use_testnet): |
| 14 | + url = TESTNET_API_URL if use_testnet else MAINNET_API_URL |
| 15 | + self.oracle_publisher_exchange: Exchange = Exchange(wallet=None, base_url=url) |
| 16 | + |
| 17 | + self.key_id = key_id |
| 18 | + self.client = boto3.client("kms", region_name=aws_region_name) |
| 19 | + # Fetch public key once so we can derive address and check recovery id |
| 20 | + pub_der = self.client.get_public_key(KeyId=key_id)["PublicKey"] |
| 21 | + |
| 22 | + from cryptography.hazmat.primitives import serialization |
| 23 | + pub = serialization.load_der_public_key(pub_der) |
| 24 | + numbers = pub.public_numbers() |
| 25 | + x = numbers.x.to_bytes(32, "big") |
| 26 | + y = numbers.y.to_bytes(32, "big") |
| 27 | + uncompressed = b"\x04" + x + y |
| 28 | + self.public_key_bytes = uncompressed |
| 29 | + self.address = "0x" + keccak(uncompressed[1:])[-20:].hex() |
| 30 | + logger.info("KMSSigner address: {}", self.address) |
| 31 | + |
| 32 | + def set_oracle(self, dex, oracle_pxs, all_mark_pxs, external_perp_pxs): |
| 33 | + timestamp = get_timestamp_ms() |
| 34 | + oracle_pxs_wire = sorted(list(oracle_pxs.items())) |
| 35 | + mark_pxs_wire = [sorted(list(mark_pxs.items())) for mark_pxs in all_mark_pxs] |
| 36 | + external_perp_pxs_wire = sorted(list(external_perp_pxs.items())) |
| 37 | + action = { |
| 38 | + "type": "perpDeploy", |
| 39 | + "setOracle": { |
| 40 | + "dex": dex, |
| 41 | + "oraclePxs": oracle_pxs_wire, |
| 42 | + "markPxs": mark_pxs_wire, |
| 43 | + "externalPerpPxs": external_perp_pxs_wire, |
| 44 | + }, |
| 45 | + } |
| 46 | + signature = self.sign_l1_action( |
| 47 | + action, |
| 48 | + timestamp, |
| 49 | + self.oracle_publisher_exchange.base_url == MAINNET_API_URL, |
| 50 | + ) |
| 51 | + return self.oracle_publisher_exchange._post_action( |
| 52 | + action, |
| 53 | + signature, |
| 54 | + timestamp, |
| 55 | + ) |
| 56 | + |
| 57 | + def sign_l1_action(self, action, nonce, is_mainnet): |
| 58 | + hash = action_hash(action, vault_address=None, nonce=nonce, expires_after=None) |
| 59 | + phantom_agent = construct_phantom_agent(hash, is_mainnet) |
| 60 | + data = l1_payload(phantom_agent) |
| 61 | + structured_data = encode_typed_data(full_message=data) |
| 62 | + message_hash = _hash_eip191_message(structured_data) |
| 63 | + signed = self.sign_message(message_hash) |
| 64 | + return {"r": to_hex(signed["r"]), "s": to_hex(signed["s"]), "v": signed["v"]} |
| 65 | + |
| 66 | + def sign_message(self, message_hash: bytes): |
| 67 | + resp = self.client.sign( |
| 68 | + KeyId=self.key_id, |
| 69 | + Message=message_hash, |
| 70 | + MessageType="DIGEST", |
| 71 | + SigningAlgorithm="ECDSA_SHA_256", # required for secp256k1 |
| 72 | + ) |
| 73 | + der_sig = resp["Signature"] |
| 74 | + |
| 75 | + seq = core.Sequence.load(der_sig) |
| 76 | + r = int(seq[0].native) |
| 77 | + s = int(seq[1].native) |
| 78 | + |
| 79 | + for recovery_id in (0, 1): |
| 80 | + candidate = Signature(vrs=(recovery_id, r, s)) |
| 81 | + pubkey = candidate.recover_public_key_from_msg_hash(message_hash) |
| 82 | + if pubkey.to_bytes() == self.public_key_bytes: |
| 83 | + v = recovery_id + 27 |
| 84 | + break |
| 85 | + else: |
| 86 | + raise ValueError("Failed to determine recovery id") |
| 87 | + |
| 88 | + return { |
| 89 | + "r": r, |
| 90 | + "s": s, |
| 91 | + "v": v, |
| 92 | + "signature": Signature(vrs=(v, r, s)).to_bytes().hex(), |
| 93 | + } |
0 commit comments