|
4 | 4 | "bytes"
|
5 | 5 | "compress/flate"
|
6 | 6 | "context"
|
| 7 | + "crypto" |
| 8 | + "crypto/ecdsa" |
7 | 9 | "crypto/rsa"
|
8 | 10 | "crypto/sha256"
|
9 | 11 | "crypto/sha512"
|
@@ -68,8 +70,9 @@ type ServiceProvider struct {
|
68 | 70 | // Entity ID is optional - if not specified then MetadataURL will be used
|
69 | 71 | EntityID string
|
70 | 72 |
|
71 |
| - // Key is the RSA private key we use to sign requests. |
72 |
| - Key *rsa.PrivateKey |
| 73 | + // Key is private key we use to sign requests. It must be either an |
| 74 | + // *rsa.PrivateKey or an *ecdsa.PrivateKey. |
| 75 | + Key crypto.Signer |
73 | 76 |
|
74 | 77 | // Certificate is the RSA public part of Key.
|
75 | 78 | Certificate *x509.Certificate
|
@@ -131,7 +134,17 @@ type ServiceProvider struct {
|
131 | 134 | // to verify signatures.
|
132 | 135 | SignatureVerifier SignatureVerifier
|
133 | 136 |
|
134 |
| - // SignatureMethod, if non-empty, authentication requests will be signed |
| 137 | + // SignatureMethod, if non-empty, authentication requests will be signed. |
| 138 | + // |
| 139 | + // The method specified here must be consistent with the type of Key. |
| 140 | + // |
| 141 | + // If Key is *rsa.PrivateKey, then this must be one of dsig.RSASHA1SignatureMethod, |
| 142 | + // dsig.RSASHA256SignatureMethod, dsig.RSASHA384SignatureMethod, or |
| 143 | + // dsig.RSASHA512SignatureMethod: |
| 144 | + // |
| 145 | + // If Key is *ecdsa.PrivateKey, then this must be one of dsig.ECDSASHA1SignatureMethod, |
| 146 | + // dsig.ECDSASHA256SignatureMethod, dsig.ECDSASHA384SignatureMethod, or |
| 147 | + // dsig.ECDSASHA512SignatureMethod. |
135 | 148 | SignatureMethod string
|
136 | 149 |
|
137 | 150 | // LogoutBindings specify the bindings available for SLO endpoint. If empty,
|
@@ -548,17 +561,38 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) {
|
548 | 561 | // for _, cert := range sp.Intermediates {
|
549 | 562 | // keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
|
550 | 563 | // }
|
551 |
| - keyStore := dsig.TLSCertKeyStore(keyPair) |
552 | 564 |
|
553 |
| - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && |
554 |
| - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && |
555 |
| - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { |
| 565 | + switch sp.SignatureMethod { |
| 566 | + case dsig.RSASHA1SignatureMethod, |
| 567 | + dsig.RSASHA256SignatureMethod, |
| 568 | + dsig.RSASHA384SignatureMethod, |
| 569 | + dsig.RSASHA512SignatureMethod: |
| 570 | + if _, ok := sp.Key.(*rsa.PrivateKey); !ok { |
| 571 | + return nil, fmt.Errorf("signature method %s requires a key of type rsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key) |
| 572 | + } |
| 573 | + |
| 574 | + case dsig.ECDSASHA1SignatureMethod, |
| 575 | + dsig.ECDSASHA256SignatureMethod, |
| 576 | + dsig.ECDSASHA384SignatureMethod, |
| 577 | + dsig.ECDSASHA512SignatureMethod: |
| 578 | + if _, ok := sp.Key.(*ecdsa.PrivateKey); !ok { |
| 579 | + return nil, fmt.Errorf("signature method %s requires a key of type ecdsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key) |
| 580 | + } |
| 581 | + default: |
556 | 582 | return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
|
557 | 583 | }
|
558 |
| - signatureMethod := sp.SignatureMethod |
559 |
| - signingContext := dsig.NewDefaultSigningContext(keyStore) |
| 584 | + |
| 585 | + keyStore := dsig.TLSCertKeyStore(keyPair) |
| 586 | + chain, err := keyStore.GetChain() |
| 587 | + if err != nil { |
| 588 | + return nil, err |
| 589 | + } |
| 590 | + signingContext, err := dsig.NewSigningContext(sp.Key, chain) |
| 591 | + if err != nil { |
| 592 | + return nil, err |
| 593 | + } |
560 | 594 | signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
|
561 |
| - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { |
| 595 | + if err := signingContext.SetSignatureMethod(sp.SignatureMethod); err != nil { |
562 | 596 | return nil, err
|
563 | 597 | }
|
564 | 598 |
|
@@ -1307,31 +1341,12 @@ func (sp *ServiceProvider) validateSignature(el *etree.Element) error {
|
1307 | 1341 |
|
1308 | 1342 | // SignLogoutRequest adds the `Signature` element to the `LogoutRequest`.
|
1309 | 1343 | func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error {
|
1310 |
| - keyPair := tls.Certificate{ |
1311 |
| - Certificate: [][]byte{sp.Certificate.Raw}, |
1312 |
| - PrivateKey: sp.Key, |
1313 |
| - Leaf: sp.Certificate, |
1314 |
| - } |
1315 |
| - // TODO: add intermediates for SP |
1316 |
| - // for _, cert := range sp.Intermediates { |
1317 |
| - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) |
1318 |
| - // } |
1319 |
| - keyStore := dsig.TLSCertKeyStore(keyPair) |
1320 |
| - |
1321 |
| - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && |
1322 |
| - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && |
1323 |
| - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { |
1324 |
| - return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) |
1325 |
| - } |
1326 |
| - signatureMethod := sp.SignatureMethod |
1327 |
| - signingContext := dsig.NewDefaultSigningContext(keyStore) |
1328 |
| - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) |
1329 |
| - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { |
| 1344 | + signingContext, err := GetSigningContext(sp) |
| 1345 | + if err != nil { |
1330 | 1346 | return err
|
1331 | 1347 | }
|
1332 | 1348 |
|
1333 | 1349 | assertionEl := req.Element()
|
1334 |
| - |
1335 | 1350 | signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
|
1336 | 1351 | if err != nil {
|
1337 | 1352 | return err
|
@@ -1361,7 +1376,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ
|
1361 | 1376 | SPNameQualifier: sp.Metadata().EntityID,
|
1362 | 1377 | },
|
1363 | 1378 | }
|
1364 |
| - if len(sp.SignatureMethod) > 0 { |
| 1379 | + if sp.SignatureMethod != "" { |
1365 | 1380 | if err := sp.SignLogoutRequest(&req); err != nil {
|
1366 | 1381 | return nil, err
|
1367 | 1382 | }
|
@@ -1475,7 +1490,7 @@ func (sp *ServiceProvider) MakeLogoutResponse(idpURL, logoutRequestID string) (*
|
1475 | 1490 | },
|
1476 | 1491 | }
|
1477 | 1492 |
|
1478 |
| - if len(sp.SignatureMethod) > 0 { |
| 1493 | + if sp.SignatureMethod != "" { |
1479 | 1494 | if err := sp.SignLogoutResponse(&response); err != nil {
|
1480 | 1495 | return nil, err
|
1481 | 1496 | }
|
@@ -1572,31 +1587,12 @@ func (r *LogoutResponse) Post(relayState string) []byte {
|
1572 | 1587 |
|
1573 | 1588 | // SignLogoutResponse adds the `Signature` element to the `LogoutResponse`.
|
1574 | 1589 | func (sp *ServiceProvider) SignLogoutResponse(resp *LogoutResponse) error {
|
1575 |
| - keyPair := tls.Certificate{ |
1576 |
| - Certificate: [][]byte{sp.Certificate.Raw}, |
1577 |
| - PrivateKey: sp.Key, |
1578 |
| - Leaf: sp.Certificate, |
1579 |
| - } |
1580 |
| - // TODO: add intermediates for SP |
1581 |
| - // for _, cert := range sp.Intermediates { |
1582 |
| - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) |
1583 |
| - // } |
1584 |
| - keyStore := dsig.TLSCertKeyStore(keyPair) |
1585 |
| - |
1586 |
| - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && |
1587 |
| - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && |
1588 |
| - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { |
1589 |
| - return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) |
1590 |
| - } |
1591 |
| - signatureMethod := sp.SignatureMethod |
1592 |
| - signingContext := dsig.NewDefaultSigningContext(keyStore) |
1593 |
| - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) |
1594 |
| - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { |
| 1590 | + signingContext, err := GetSigningContext(sp) |
| 1591 | + if err != nil { |
1595 | 1592 | return err
|
1596 | 1593 | }
|
1597 | 1594 |
|
1598 | 1595 | assertionEl := resp.Element()
|
1599 |
| - |
1600 | 1596 | signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
|
1601 | 1597 | if err != nil {
|
1602 | 1598 | return err
|
|
0 commit comments