@@ -6,13 +6,17 @@ import (
66 "encoding/hex"
77 "encoding/json"
88 "fmt"
9+ "math/big"
910 "strings"
1011 "testing"
1112
1213 "crypto/x509"
1314 "encoding/pem"
1415
16+ "crypto/ecdsa"
17+
1518 "github.com/Dstack-TEE/dstack/sdk/go/dstack"
19+ "github.com/ethereum/go-ethereum/crypto"
1620)
1721
1822func TestGetKey (t * testing.T ) {
@@ -406,3 +410,185 @@ func TestInfo(t *testing.T) {
406410 t .Error ("expected event log to not be empty" )
407411 }
408412}
413+
414+ func TestGetKeySignatureVerification (t * testing.T ) {
415+ expectedAppPubkey , _ := hex .DecodeString ("02b85cceca0c02d878f0ebcda72a97469a472416eb6faf3c4807642132f9786810" )
416+ expectedKmsPubkey , _ := hex .DecodeString ("02cad3a8bb11c5c0858fb3e402048b5137457039d577986daade678ed4b4ab1b9b" )
417+
418+ client := dstack .NewDstackClient ()
419+ path := "/test/path"
420+ purpose := "test-purpose"
421+ resp , err := client .GetKey (context .Background (), path , purpose )
422+ if err != nil {
423+ t .Fatal (err )
424+ }
425+
426+ if resp .Key == "" {
427+ t .Error ("expected key to not be empty" )
428+ }
429+
430+ if len (resp .SignatureChain ) != 2 {
431+ t .Fatalf ("expected signature chain to have 2 elements, got %d" , len (resp .SignatureChain ))
432+ }
433+
434+ // Extract the app signature and KMS signature from the chain
435+ appSignatureHex := resp .SignatureChain [0 ]
436+ kmsSignatureHex := resp .SignatureChain [1 ]
437+
438+ // Convert hex strings to bytes
439+ appSignature , err := hex .DecodeString (appSignatureHex )
440+ if err != nil {
441+ t .Fatalf ("failed to decode app signature: %v" , err )
442+ }
443+
444+ kmsSignature , err := hex .DecodeString (kmsSignatureHex )
445+ if err != nil {
446+ t .Fatalf ("failed to decode KMS signature: %v" , err )
447+ }
448+
449+ // Verify signatures have the correct format (signature + recovery ID)
450+ if len (appSignature ) != 65 {
451+ t .Errorf ("expected app signature to be 65 bytes (64 bytes signature + 1 byte recovery ID), got %d" , len (appSignature ))
452+ }
453+
454+ if len (kmsSignature ) != 65 {
455+ t .Errorf ("expected KMS signature to be 65 bytes (64 bytes signature + 1 byte recovery ID), got %d" , len (kmsSignature ))
456+ }
457+
458+ // Get app info to retrieve app ID for verification
459+ infoResp , err := client .Info (context .Background ())
460+ if err != nil {
461+ t .Fatal (err )
462+ }
463+
464+ // 1. Derive the public key from the private key
465+ derivedPrivKey := resp .Key
466+ derivedPubKey , err := derivePublicKeyFromPrivate (derivedPrivKey )
467+ if err != nil {
468+ t .Fatalf ("failed to derive public key: %v" , err )
469+ }
470+
471+ // 2. Construct the message that was signed
472+ message := fmt .Sprintf ("%s:%s" , purpose , hex .EncodeToString (derivedPubKey ))
473+
474+ // 3. Recover the app's public key from the signature
475+ appPubKey , err := recoverPublicKey (message , appSignature )
476+ if err != nil {
477+ t .Fatalf ("failed to recover app public key: %v" , err )
478+ }
479+
480+ // Convert the recovered public key to compressed format for comparison
481+ appPubKeyCompressed , err := compressPublicKey (appPubKey )
482+ if err != nil {
483+ t .Fatalf ("failed to compress recovered public key: %v" , err )
484+ }
485+
486+ if ! bytes .Equal (appPubKeyCompressed , expectedAppPubkey ) {
487+ t .Errorf ("app public key mismatch:\n Expected: %s\n Actual: %s" ,
488+ hex .EncodeToString (expectedAppPubkey ),
489+ hex .EncodeToString (appPubKeyCompressed ))
490+ }
491+
492+ // 4. Verify the app ID matches what we expect
493+ // The app ID should be derivable from the app's public key
494+ // or should match what's returned from the Info endpoint
495+ appIDFromInfo , err := hex .DecodeString (infoResp .AppID )
496+ if err != nil {
497+ t .Fatalf ("failed to decode app ID: %v" , err )
498+ }
499+
500+ // 5. Construct the message that KMS would have signed
501+ // This would typically be something like "dstack-kms-issued:{app_id}{app_public_key}"
502+ kmsMessage := fmt .Sprintf ("dstack-kms-issued:%s%s" , appIDFromInfo , string (appPubKeyCompressed ))
503+ kmsPubKey , err := recoverPublicKey (kmsMessage , kmsSignature )
504+ if err != nil {
505+ t .Fatalf ("failed to recover KMS public key: %v" , err )
506+ }
507+
508+ kmsPubKeyCompressed , err := compressPublicKey (kmsPubKey )
509+ if err != nil {
510+ t .Fatalf ("failed to compress KMS public key: %v" , err )
511+ }
512+
513+ if ! bytes .Equal (kmsPubKeyCompressed , expectedKmsPubkey ) {
514+ t .Errorf ("KMS public key mismatch:\n Expected: %s\n Actual: %s" ,
515+ hex .EncodeToString (expectedKmsPubkey ),
516+ hex .EncodeToString (kmsPubKeyCompressed ))
517+ }
518+
519+ // Verify that the recovered app public key can verify the app signature
520+ verified , err := verifySignature (message , appSignature , appPubKey )
521+ if err != nil {
522+ t .Fatalf ("signature verification error: %v" , err )
523+ }
524+ if ! verified {
525+ t .Error ("app signature verification failed" )
526+ }
527+ }
528+
529+ // Helper function to derive a public key from a private key
530+ func derivePublicKeyFromPrivate (privateKeyHex string ) ([]byte , error ) {
531+ privateKeyBytes , err := hex .DecodeString (privateKeyHex )
532+ if err != nil {
533+ return nil , fmt .Errorf ("failed to decode private key: %w" , err )
534+ }
535+
536+ // Import the private key
537+ privateKey , err := crypto .ToECDSA (privateKeyBytes )
538+ if err != nil {
539+ return nil , fmt .Errorf ("failed to convert to ECDSA private key: %w" , err )
540+ }
541+
542+ // Derive the public key in compressed format
543+ publicKey := crypto .CompressPubkey (& privateKey .PublicKey )
544+ return publicKey , nil
545+ }
546+
547+ // Helper function to recover a public key from a signature
548+ func recoverPublicKey (message string , signature []byte ) ([]byte , error ) {
549+ if len (signature ) != 65 {
550+ return nil , fmt .Errorf ("invalid signature length: expected 65 bytes, got %d" , len (signature ))
551+ }
552+
553+ // Hash the message using Keccak256
554+ messageHash := crypto .Keccak256 ([]byte (message ))
555+
556+ // Recover the public key
557+ pubKey , err := crypto .Ecrecover (messageHash , signature )
558+ if err != nil {
559+ return nil , fmt .Errorf ("failed to recover public key: %w" , err )
560+ }
561+
562+ return pubKey , nil
563+ }
564+
565+ // Helper function to verify a signature
566+ func verifySignature (message string , signature []byte , publicKey []byte ) (bool , error ) {
567+ if len (signature ) != 65 {
568+ return false , fmt .Errorf ("invalid signature length: expected 65 bytes, got %d" , len (signature ))
569+ }
570+
571+ // Hash the message using Keccak256
572+ messageHash := crypto .Keccak256 ([]byte (message ))
573+
574+ // The last byte is the recovery ID, we need to remove it for verification
575+ signatureWithoutRecoveryID := signature [:64 ]
576+
577+ // Verify the signature
578+ return crypto .VerifySignature (publicKey , messageHash , signatureWithoutRecoveryID ), nil
579+ }
580+
581+ // Add this helper function to compress a public key
582+ func compressPublicKey (uncompressedKey []byte ) ([]byte , error ) {
583+ if len (uncompressedKey ) < 65 || uncompressedKey [0 ] != 4 {
584+ return nil , fmt .Errorf ("invalid uncompressed public key" )
585+ }
586+ x := new (big.Int ).SetBytes (uncompressedKey [1 :33 ])
587+ y := new (big.Int ).SetBytes (uncompressedKey [33 :65 ])
588+ pubKey := & ecdsa.PublicKey {
589+ Curve : crypto .S256 (),
590+ X : x ,
591+ Y : y ,
592+ }
593+ return crypto .CompressPubkey (pubKey ), nil
594+ }
0 commit comments