Skip to content

Commit b83c3c2

Browse files
committed
Add signature verification in go tests
1 parent 807d257 commit b83c3c2

File tree

3 files changed

+208
-1
lines changed

3 files changed

+208
-1
lines changed

sdk/go/dstack/client_test.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@ import (
66
"encoding/hex"
77
"encoding/json"
88
"fmt"
9+
"math/big"
910
"strings"
1011
"testing"
1112

1213
"crypto/x509"
1314
"encoding/pem"
1415

16+
"crypto/ecdsa"
17+
1518
"github.com/Dstack-TEE/dstack/sdk/go/dstack"
19+
"github.com/ethereum/go-ethereum/crypto"
1620
)
1721

1822
func TestGetKey(t *testing.T) {
@@ -406,3 +410,185 @@ func TestInfo(t *testing.T) {
406410
t.Error("expected event log to not be empty")
407411
}
408412
}
413+
414+
func TestGetKeySignatureVerification(t *testing.T) {
415+
expectedAppPubkey, _ := hex.DecodeString("02b85cceca0c02d878f0ebcda72a97469a472416eb6faf3c4807642132f9786810")
416+
expectedKmsPubkey, _ := hex.DecodeString("02cad3a8bb11c5c0858fb3e402048b5137457039d577986daade678ed4b4ab1b9b")
417+
418+
client := dstack.NewDstackClient()
419+
path := "/test/path"
420+
purpose := "test-purpose"
421+
resp, err := client.GetKey(context.Background(), path, purpose)
422+
if err != nil {
423+
t.Fatal(err)
424+
}
425+
426+
if resp.Key == "" {
427+
t.Error("expected key to not be empty")
428+
}
429+
430+
if len(resp.SignatureChain) != 2 {
431+
t.Fatalf("expected signature chain to have 2 elements, got %d", len(resp.SignatureChain))
432+
}
433+
434+
// Extract the app signature and KMS signature from the chain
435+
appSignatureHex := resp.SignatureChain[0]
436+
kmsSignatureHex := resp.SignatureChain[1]
437+
438+
// Convert hex strings to bytes
439+
appSignature, err := hex.DecodeString(appSignatureHex)
440+
if err != nil {
441+
t.Fatalf("failed to decode app signature: %v", err)
442+
}
443+
444+
kmsSignature, err := hex.DecodeString(kmsSignatureHex)
445+
if err != nil {
446+
t.Fatalf("failed to decode KMS signature: %v", err)
447+
}
448+
449+
// Verify signatures have the correct format (signature + recovery ID)
450+
if len(appSignature) != 65 {
451+
t.Errorf("expected app signature to be 65 bytes (64 bytes signature + 1 byte recovery ID), got %d", len(appSignature))
452+
}
453+
454+
if len(kmsSignature) != 65 {
455+
t.Errorf("expected KMS signature to be 65 bytes (64 bytes signature + 1 byte recovery ID), got %d", len(kmsSignature))
456+
}
457+
458+
// Get app info to retrieve app ID for verification
459+
infoResp, err := client.Info(context.Background())
460+
if err != nil {
461+
t.Fatal(err)
462+
}
463+
464+
// 1. Derive the public key from the private key
465+
derivedPrivKey := resp.Key
466+
derivedPubKey, err := derivePublicKeyFromPrivate(derivedPrivKey)
467+
if err != nil {
468+
t.Fatalf("failed to derive public key: %v", err)
469+
}
470+
471+
// 2. Construct the message that was signed
472+
message := fmt.Sprintf("%s:%s", purpose, hex.EncodeToString(derivedPubKey))
473+
474+
// 3. Recover the app's public key from the signature
475+
appPubKey, err := recoverPublicKey(message, appSignature)
476+
if err != nil {
477+
t.Fatalf("failed to recover app public key: %v", err)
478+
}
479+
480+
// Convert the recovered public key to compressed format for comparison
481+
appPubKeyCompressed, err := compressPublicKey(appPubKey)
482+
if err != nil {
483+
t.Fatalf("failed to compress recovered public key: %v", err)
484+
}
485+
486+
if !bytes.Equal(appPubKeyCompressed, expectedAppPubkey) {
487+
t.Errorf("app public key mismatch:\nExpected: %s\nActual: %s",
488+
hex.EncodeToString(expectedAppPubkey),
489+
hex.EncodeToString(appPubKeyCompressed))
490+
}
491+
492+
// 4. Verify the app ID matches what we expect
493+
// The app ID should be derivable from the app's public key
494+
// or should match what's returned from the Info endpoint
495+
appIDFromInfo, err := hex.DecodeString(infoResp.AppID)
496+
if err != nil {
497+
t.Fatalf("failed to decode app ID: %v", err)
498+
}
499+
500+
// 5. Construct the message that KMS would have signed
501+
// This would typically be something like "dstack-kms-issued:{app_id}{app_public_key}"
502+
kmsMessage := fmt.Sprintf("dstack-kms-issued:%s%s", appIDFromInfo, string(appPubKeyCompressed))
503+
kmsPubKey, err := recoverPublicKey(kmsMessage, kmsSignature)
504+
if err != nil {
505+
t.Fatalf("failed to recover KMS public key: %v", err)
506+
}
507+
508+
kmsPubKeyCompressed, err := compressPublicKey(kmsPubKey)
509+
if err != nil {
510+
t.Fatalf("failed to compress KMS public key: %v", err)
511+
}
512+
513+
if !bytes.Equal(kmsPubKeyCompressed, expectedKmsPubkey) {
514+
t.Errorf("KMS public key mismatch:\nExpected: %s\nActual: %s",
515+
hex.EncodeToString(expectedKmsPubkey),
516+
hex.EncodeToString(kmsPubKeyCompressed))
517+
}
518+
519+
// Verify that the recovered app public key can verify the app signature
520+
verified, err := verifySignature(message, appSignature, appPubKey)
521+
if err != nil {
522+
t.Fatalf("signature verification error: %v", err)
523+
}
524+
if !verified {
525+
t.Error("app signature verification failed")
526+
}
527+
}
528+
529+
// Helper function to derive a public key from a private key
530+
func derivePublicKeyFromPrivate(privateKeyHex string) ([]byte, error) {
531+
privateKeyBytes, err := hex.DecodeString(privateKeyHex)
532+
if err != nil {
533+
return nil, fmt.Errorf("failed to decode private key: %w", err)
534+
}
535+
536+
// Import the private key
537+
privateKey, err := crypto.ToECDSA(privateKeyBytes)
538+
if err != nil {
539+
return nil, fmt.Errorf("failed to convert to ECDSA private key: %w", err)
540+
}
541+
542+
// Derive the public key in compressed format
543+
publicKey := crypto.CompressPubkey(&privateKey.PublicKey)
544+
return publicKey, nil
545+
}
546+
547+
// Helper function to recover a public key from a signature
548+
func recoverPublicKey(message string, signature []byte) ([]byte, error) {
549+
if len(signature) != 65 {
550+
return nil, fmt.Errorf("invalid signature length: expected 65 bytes, got %d", len(signature))
551+
}
552+
553+
// Hash the message using Keccak256
554+
messageHash := crypto.Keccak256([]byte(message))
555+
556+
// Recover the public key
557+
pubKey, err := crypto.Ecrecover(messageHash, signature)
558+
if err != nil {
559+
return nil, fmt.Errorf("failed to recover public key: %w", err)
560+
}
561+
562+
return pubKey, nil
563+
}
564+
565+
// Helper function to verify a signature
566+
func verifySignature(message string, signature []byte, publicKey []byte) (bool, error) {
567+
if len(signature) != 65 {
568+
return false, fmt.Errorf("invalid signature length: expected 65 bytes, got %d", len(signature))
569+
}
570+
571+
// Hash the message using Keccak256
572+
messageHash := crypto.Keccak256([]byte(message))
573+
574+
// The last byte is the recovery ID, we need to remove it for verification
575+
signatureWithoutRecoveryID := signature[:64]
576+
577+
// Verify the signature
578+
return crypto.VerifySignature(publicKey, messageHash, signatureWithoutRecoveryID), nil
579+
}
580+
581+
// Add this helper function to compress a public key
582+
func compressPublicKey(uncompressedKey []byte) ([]byte, error) {
583+
if len(uncompressedKey) < 65 || uncompressedKey[0] != 4 {
584+
return nil, fmt.Errorf("invalid uncompressed public key")
585+
}
586+
x := new(big.Int).SetBytes(uncompressedKey[1:33])
587+
y := new(big.Int).SetBytes(uncompressedKey[33:65])
588+
pubKey := &ecdsa.PublicKey{
589+
Curve: crypto.S256(),
590+
X: x,
591+
Y: y,
592+
}
593+
return crypto.CompressPubkey(pubKey), nil
594+
}

sdk/go/go.mod

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
module github.com/Dstack-TEE/dstack/sdk/go
22

3-
go 1.21.6
3+
go 1.23.0
4+
5+
toolchain go1.23.8
6+
7+
require (
8+
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
9+
github.com/ethereum/go-ethereum v1.15.7 // indirect
10+
github.com/holiman/uint256 v1.3.2 // indirect
11+
golang.org/x/crypto v0.35.0 // indirect
12+
golang.org/x/sys v0.30.0 // indirect
13+
)

sdk/go/go.sum

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc=
2+
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc=
3+
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs=
4+
github.com/ethereum/go-ethereum v1.15.7 h1:vm1XXruZVnqtODBgqFaTclzP0xAvCvQIDKyFNUA1JpY=
5+
github.com/ethereum/go-ethereum v1.15.7/go.mod h1:+S9k+jFzlyVTNcYGvqFhzN/SFhI6vA+aOY4T5tLSPL0=
6+
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
7+
github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E=
8+
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
9+
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
10+
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
11+
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

0 commit comments

Comments
 (0)