diff --git a/jwk/convert.go b/jwk/convert.go index 244a7977d..3b70771c3 100644 --- a/jwk/convert.go +++ b/jwk/convert.go @@ -4,13 +4,16 @@ import ( "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" + "crypto/elliptic" "crypto/rsa" "errors" "fmt" + "math/big" "reflect" "sync" "github.com/lestrrat-go/blackmagic" + "github.com/lestrrat-go/jwx/v3/internal/ecutil" "github.com/lestrrat-go/jwx/v3/jwa" ) @@ -119,20 +122,113 @@ func init() { } { f := KeyImportFunc(okpPrivateKeyToJWK) - for _, k := range []interface{}{ed25519.PrivateKey(nil), ecdh.PrivateKey{}, &ecdh.PrivateKey{}} { + for _, k := range []interface{}{ed25519.PrivateKey(nil)} { + RegisterKeyImporter(k, f) + } + } + { + f := KeyImportFunc(ecdhPrivateKeyToJWK) + for _, k := range []interface{}{ecdh.PrivateKey{}, &ecdh.PrivateKey{}} { RegisterKeyImporter(k, f) } } { f := KeyImportFunc(okpPublicKeyToJWK) - for _, k := range []interface{}{ed25519.PublicKey(nil), ecdh.PublicKey{}, &ecdh.PublicKey{}} { + for _, k := range []interface{}{ed25519.PublicKey(nil)} { + RegisterKeyImporter(k, f) + } + } + { + f := KeyImportFunc(ecdhPublicKeyToJWK) + for _, k := range []interface{}{ecdh.PublicKey{}, &ecdh.PublicKey{}} { RegisterKeyImporter(k, f) } } - RegisterKeyImporter([]byte(nil), KeyImportFunc(bytesToKey)) } +func ecdhPrivateKeyToJWK(src interface{}) (Key, error) { + var raw *ecdh.PrivateKey + switch src := src.(type) { + case *ecdh.PrivateKey: + raw = src + case ecdh.PrivateKey: + raw = &src + default: + return nil, fmt.Errorf(`cannot convert key type '%T' to OKP jwk.Key`, src) + } + + switch raw.Curve() { + case ecdh.X25519(): + return okpPrivateKeyToJWK(raw) + case ecdh.P256(): + return ecdhPrivateKeyToECJWK(raw, elliptic.P256()) + case ecdh.P384(): + return ecdhPrivateKeyToECJWK(raw, elliptic.P384()) + case ecdh.P521(): + return ecdhPrivateKeyToECJWK(raw, elliptic.P521()) + default: + return nil, fmt.Errorf(`unsupported curve %s`, raw.Curve()) + } +} + +func ecdhPrivateKeyToECJWK(raw *ecdh.PrivateKey, crv elliptic.Curve) (Key, error) { + pub := raw.PublicKey() + rawpub := pub.Bytes() + + size := ecutil.CalculateKeySize(crv) + var x, y, d big.Int + x.SetBytes(rawpub[1 : 1+size]) + y.SetBytes(rawpub[1+size:]) + d.SetBytes(raw.Bytes()) + + var ecdsaPriv ecdsa.PrivateKey + ecdsaPriv.Curve = crv + ecdsaPriv.D = &d + ecdsaPriv.X = &x + ecdsaPriv.Y = &y + return ecdsaPrivateKeyToJWK(&ecdsaPriv) +} + +func ecdhPublicKeyToJWK(src interface{}) (Key, error) { + var raw *ecdh.PublicKey + switch src := src.(type) { + case *ecdh.PublicKey: + raw = src + case ecdh.PublicKey: + raw = &src + default: + return nil, fmt.Errorf(`cannot convert key type '%T' to OKP jwk.Key`, src) + } + + switch raw.Curve() { + case ecdh.X25519(): + return okpPublicKeyToJWK(raw) + case ecdh.P256(): + return ecdhPublicKeyToECJWK(raw, elliptic.P256()) + case ecdh.P384(): + return ecdhPublicKeyToECJWK(raw, elliptic.P384()) + case ecdh.P521(): + return ecdhPublicKeyToECJWK(raw, elliptic.P521()) + default: + return nil, fmt.Errorf(`unsupported curve %s`, raw.Curve()) + } +} + +func ecdhPublicKeyToECJWK(raw *ecdh.PublicKey, crv elliptic.Curve) (Key, error) { + rawbytes := raw.Bytes() + size := ecutil.CalculateKeySize(crv) + var x, y big.Int + + x.SetBytes(rawbytes[1 : 1+size]) + y.SetBytes(rawbytes[1+size:]) + var ecdsaPriv ecdsa.PublicKey + ecdsaPriv.Curve = crv + ecdsaPriv.X = &x + ecdsaPriv.Y = &y + return ecdsaPublicKeyToJWK(&ecdsaPriv) +} + // These may seem a bit repetitive and redandunt, but the problem is that // each key type has its own Import method -- for example, Import(*ecdsa.PrivateKey) // vs Import(*rsa.PrivateKey), and therefore they can't just be bundled into @@ -277,21 +373,21 @@ func Export(key Key, dst interface{}) error { muKeyExporters.RLock() exporters, ok := keyExporters[key.KeyType()] muKeyExporters.RUnlock() - if ok { - for _, conv := range exporters { - v, err := conv.Export(key, dst) - if err != nil { - if errors.Is(err, ContinueError()) { - continue - } - return fmt.Errorf(`jwk.Export: failed to export jwk.Key to raw format: %w`, err) - } - - if err := blackmagic.AssignIfCompatible(dst, v); err != nil { - return fmt.Errorf(`jwk.Export: failed to assign key: %w`, err) + if !ok { + return fmt.Errorf(`jwk.Export: no exporters registered for key type '%T'`, key) + } + for _, conv := range exporters { + v, err := conv.Export(key, dst) + if err != nil { + if errors.Is(err, ContinueError()) { + continue } - return nil + return fmt.Errorf(`jwk.Export: failed to export jwk.Key to raw format: %w`, err) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`jwk.Export: failed to assign key: %w`, err) } + return nil } - return fmt.Errorf(`jwk.Export: failed to find exporter for key type '%T'`, key) + return fmt.Errorf(`jwk.Export: no suitable exporter found for key type '%T'`, key) } diff --git a/jwk/ecdsa.go b/jwk/ecdsa.go index 9b4027a21..8507cc800 100644 --- a/jwk/ecdsa.go +++ b/jwk/ecdsa.go @@ -2,10 +2,12 @@ package jwk import ( "crypto" + "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "fmt" "math/big" + "reflect" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/ecutil" @@ -102,13 +104,59 @@ func buildECDSAPublicKey(alg jwa.EllipticCurveAlgorithm, xbuf, ybuf []byte) (*ec return &ecdsa.PublicKey{Curve: crv, X: &x, Y: &y}, nil } +func buildECDHPublicKey(alg jwa.EllipticCurveAlgorithm, xbuf, ybuf []byte) (*ecdh.PublicKey, error) { + var ecdhcrv ecdh.Curve + switch alg { + case jwa.X25519(): + ecdhcrv = ecdh.X25519() + case jwa.P256(): + ecdhcrv = ecdh.P256() + case jwa.P384(): + ecdhcrv = ecdh.P384() + case jwa.P521(): + ecdhcrv = ecdh.P521() + default: + return nil, fmt.Errorf(`jwk: unsupported ECDH curve %s`, alg) + } + + return ecdhcrv.NewPublicKey(append([]byte{0x04}, append(xbuf, ybuf...)...)) +} + +func buildECDHPrivateKey(alg jwa.EllipticCurveAlgorithm, dbuf []byte) (*ecdh.PrivateKey, error) { + var ecdhcrv ecdh.Curve + switch alg { + case jwa.X25519(): + ecdhcrv = ecdh.X25519() + case jwa.P256(): + ecdhcrv = ecdh.P256() + case jwa.P384(): + ecdhcrv = ecdh.P384() + case jwa.P521(): + ecdhcrv = ecdh.P521() + default: + return nil, fmt.Errorf(`jwk: unsupported ECDH curve %s`, alg) + } + + return ecdhcrv.NewPrivateKey(dbuf) +} + func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { + var isECDH bool switch k := keyif.(type) { case *ecdsaPublicKey: switch hint.(type) { - case ecdsa.PublicKey, *ecdsa.PublicKey, interface{}: + case ecdsa.PublicKey, *ecdsa.PublicKey: + case ecdh.PublicKey, *ecdh.PublicKey: + isECDH = true default: - return nil, fmt.Errorf(`invalid destination object type %T: %w`, hint, ContinueError()) + rv := reflect.ValueOf(hint) + //nolint:revive + if rv.Kind() == reflect.Ptr && rv.Elem().Kind() == reflect.Interface { + // pointer to an interface value, presumably they want us to dynamically + // create an object of the right type + } else { + return nil, fmt.Errorf(`invalid destination object type %T: %w`, hint, ContinueError()) + } } k.mu.RLock() @@ -118,12 +166,26 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { if !ok { return nil, fmt.Errorf(`missing "crv" field`) } - return buildECDSAPublicKey(crv, k.x, k.y) + + if isECDH { + return buildECDHPublicKey(crv, k.x, k.y) + } else { + return buildECDSAPublicKey(crv, k.x, k.y) + } case *ecdsaPrivateKey: switch hint.(type) { - case ecdsa.PrivateKey, *ecdsa.PrivateKey, interface{}: + case ecdsa.PrivateKey, *ecdsa.PrivateKey: + case ecdh.PrivateKey, *ecdh.PrivateKey: + isECDH = true default: - return nil, fmt.Errorf(`invalid destination object type %T: %w`, hint, ContinueError()) + rv := reflect.ValueOf(hint) + //nolint:revive + if rv.Kind() == reflect.Ptr && rv.Elem().Kind() == reflect.Interface { + // pointer to an interface value, presumably they want us to dynamically + // create an object of the right type + } else { + return nil, fmt.Errorf(`invalid destination object type %T: %w`, hint, ContinueError()) + } } k.mu.RLock() @@ -133,6 +195,10 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { if !ok { return nil, fmt.Errorf(`missing "crv" field`) } + + if isECDH { + return buildECDHPrivateKey(crv, k.d) + } pubk, err := buildECDSAPublicKey(crv, k.x, k.y) if err != nil { return nil, fmt.Errorf(`failed to build public key: %w`, err) diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 318fc827d..57286c447 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -1997,3 +1997,44 @@ func TestParse_fail(t *testing.T) { }) }) } + +func TestGH1262(t *testing.T) { + t.Run("Updated Example test", func(t *testing.T) { + keyCli, err := ecdh.P384().GenerateKey(rand.Reader) + require.NoError(t, err, `ecdh.P384().GenerateKey should succeed`) + + jwkCliPriv, err := jwk.Import(keyCli) + require.NoError(t, err, `jwk.Import should succeed`) + _ = jwkCliPriv + + var rawCliPriv ecdh.PrivateKey + require.NoError(t, jwk.Export(jwkCliPriv, &rawCliPriv), `jwk.Export should succeed`) + + pubCli := keyCli.PublicKey() // server is able to retrieve the pub key part of client + + keySrv, err := ecdh.P384().GenerateKey(rand.Reader) + require.NoError(t, err, `ecdh.P384().GenerateKey should succeed`) + + jwkSrv, err := jwk.Import(keySrv.PublicKey()) + require.NoError(t, err, `jwk.Import should succeed`) + jwkBuf, err := json.Marshal(jwkSrv) + + require.NoError(t, err, `json.Marshal should succeed`) + + secretSrv, err := keySrv.ECDH(pubCli) + require.NoError(t, err, `keySrv.ECDH should succeed`) + + _ = secretSrv // doing some non-standard encryption & response with encrypted data + + // client + pubSrv := &ecdh.PublicKey{} + jwkCli, err := jwk.ParseKey(jwkBuf) // extract jwkBuf + require.NoError(t, err, `jwk.ParseKey should succeed`) + + require.NoError(t, jwk.Export(jwkCli, pubSrv), `jwk.Export should succeed`) + secretCli, err := keyCli.ECDH(pubSrv) + require.NoError(t, err, `keyCli.ECDH should succeed`) + + _ = secretCli // doing some non-standard encryption + }) +}