@@ -4,11 +4,14 @@ package umbilical
44
55import (
66 "github.com/bwesterb/mtc"
7+ "github.com/bwesterb/mtc/umbilical/revocation"
78
9+ "bytes"
810 "crypto/tls"
911 "crypto/x509"
1012 "errors"
1113 "fmt"
14+ "slices"
1215 "strings"
1316 "sync"
1417)
@@ -83,3 +86,144 @@ func GetChainFromTLSServer(addr string) (chain []*x509.Certificate, err error) {
8386 wg .Wait ()
8487 return
8588}
89+
90+ // Checks whether the given assertion (to be) issued in the given batch
91+ // is consistent with the given X.509 certificate chain and
92+ // trusted roots. The assertion is allowed to cover less than the certificate:
93+ // eg, only example.com where the certificate covers some.example.com too.
94+ //
95+ // On the other hand, we are more strict than is perhaps required. For
96+ // instance, we do not allow an assertion for some.example.com to be backed
97+ // by a wildcard certificate for *.example.com.
98+ // Also we require basically the same chain to be valid for the full
99+ // duration of the assertion.
100+ //
101+ // If rc is set, checks whether the certificate is revoked. Does not check
102+ // revocation of intermediates.
103+ //
104+ // If consistent, returns one or more verified chains.
105+ func CheckAssertionValidForX509 (a mtc.Assertion , batch mtc.Batch ,
106+ chain []* x509.Certificate , roots * x509.CertPool , rc * revocation.Checker ) (
107+ [][]* x509.Certificate , error ) {
108+ if len (chain ) == 0 {
109+ return nil , errors .New ("empty chain" )
110+ }
111+
112+ cert := chain [0 ]
113+
114+ // Check if the claims are covered by the certificate.
115+ for _ , ip := range slices .Concat (a .Claims .IPv4 , a .Claims .IPv6 ) {
116+ ok := false
117+ for _ , ip2 := range cert .IPAddresses {
118+ if ip2 .Equal (ip ) {
119+ ok = true
120+ break
121+ }
122+ }
123+
124+ if ! ok {
125+ return nil , fmt .Errorf ("X.509 certificate not valid for %s" , ip )
126+ }
127+ }
128+
129+ got := make (map [string ]struct {})
130+ for _ , name := range cert .DNSNames {
131+ got [name ] = struct {}{}
132+ }
133+ for _ , name := range a .Claims .DNS {
134+ if _ , ok := got [name ]; ! ok {
135+ return nil , fmt .Errorf (
136+ "No exact match for %s in provided X.509 cert" ,
137+ name ,
138+ )
139+ }
140+ }
141+ for _ , name := range a .Claims .DNSWildcard {
142+ if _ , ok := got ["*." + name ]; ! ok {
143+ return nil , fmt .Errorf (
144+ "No exact match for *.%s in provided X.509 cert" ,
145+ name ,
146+ )
147+ }
148+ }
149+
150+ if len (a .Claims .Unknown ) != 0 {
151+ return nil , errors .New ("unknown claims" )
152+ }
153+
154+ // Check if subjects match.
155+ if a .Subject .Type () != mtc .TLSSubjectType {
156+ return nil , errors .New ("Expected TLSSubjectType" )
157+ }
158+ subjVerifier , err := a .Subject .(* mtc.TLSSubject ).Verifier ()
159+ if err != nil {
160+ return nil , fmt .Errorf ("Assertion Subject: %w" , err )
161+ }
162+
163+ certSubject , err := mtc .NewTLSSubject (subjVerifier .Scheme (), cert .PublicKey )
164+ if err != nil {
165+ return nil , fmt .Errorf ("NewTLSSubject(X.509 public key): %w" , err )
166+ }
167+ if ! bytes .Equal (certSubject .Info (), a .Subject .Info ()) {
168+ return nil , fmt .Errorf ("Subjects don't match" )
169+ }
170+
171+ // Verify chain at the start of the batch's validity period
172+ start , end := batch .ValidityInterval ()
173+
174+ opts := x509.VerifyOptions {
175+ Roots : roots ,
176+ Intermediates : x509 .NewCertPool (),
177+ CurrentTime : start ,
178+ }
179+ for _ , cert2 := range chain [1 :] {
180+ opts .Intermediates .AddCert (cert2 )
181+ }
182+ chains , err := cert .Verify (opts )
183+ if err != nil {
184+ return nil , fmt .Errorf ("X.509 Verify: %w" , err )
185+ }
186+
187+ var ret [][]* x509.Certificate
188+ var errs []error
189+
190+ // Verify each chain at the end of the batch's validity period
191+ for _ , candidateChain := range chains {
192+ opts = x509.VerifyOptions {
193+ Roots : x509 .NewCertPool (),
194+ Intermediates : x509 .NewCertPool (),
195+ CurrentTime : end ,
196+ }
197+
198+ for _ , cert2 := range candidateChain [1 : len (candidateChain )- 1 ] {
199+ opts .Intermediates .AddCert (cert2 )
200+ }
201+ opts .Roots .AddCert (candidateChain [len (candidateChain )- 1 ])
202+ _ , err := cert .Verify (opts )
203+ if err != nil {
204+ errs = append (errs , err )
205+ continue
206+ }
207+ ret = append (ret , candidateChain )
208+ }
209+
210+ if len (ret ) == 0 {
211+ return nil , fmt .Errorf (
212+ "Could not find chain valid during lifetime of certificate: %w" ,
213+ errors .Join (errs ... ),
214+ )
215+ }
216+
217+ if rc != nil {
218+ revoked , err := rc .Revoked (ret [0 ][0 ], ret [0 ][1 ])
219+ if err != nil {
220+ return nil , fmt .Errorf ("checking revocation: %w" , err )
221+ }
222+
223+ if revoked {
224+ return nil , errors .New ("certificate is revoked" )
225+ }
226+ }
227+
228+ return ret , nil
229+ }
0 commit comments