@@ -92,7 +92,7 @@ func TestCheckAuth(t *testing.T) {
9292 incorrectPwd = "badpwd"
9393 )
9494
95- localhost , localhostTLS := startServer (t , uname , pwd )
95+ localhost , localhostTLS , cert := startServer (t , uname , pwd )
9696
9797 _ , portTLS , err := net .SplitHostPort (localhostTLS )
9898 if err != nil {
@@ -132,7 +132,6 @@ func TestCheckAuth(t *testing.T) {
132132 },
133133 wantErr : false ,
134134 },
135-
136135 {
137136 name : "correct credentials non-localhost" ,
138137 args : args {
@@ -170,7 +169,30 @@ func TestCheckAuth(t *testing.T) {
170169 Username : tt .args .username ,
171170 Password : tt .args .password ,
172171 }
173- if err := creds .CheckAuth (tt .args .ctx , tt .args .registry + "/someorg/someimage:sometag" , c , http .DefaultTransport ); (err != nil ) != tt .wantErr {
172+ // create trusted certificates pool and add our certificate
173+ certPool := x509 .NewCertPool ()
174+ certPool .AddCert (cert )
175+
176+ // client transport with the certificate
177+ transport := & http.Transport {
178+ TLSClientConfig : & tls.Config {
179+ RootCAs : certPool ,
180+ },
181+ }
182+
183+ dialer := & net.Dialer {}
184+
185+ transport .DialContext = func (ctx context.Context , network , addr string ) (net.Conn , error ) {
186+ h , p , err := net .SplitHostPort (addr )
187+ if err != nil {
188+ return nil , err
189+ }
190+ if h == "test.io" {
191+ h = "localhost"
192+ }
193+ return dialer .DialContext (ctx , network , net .JoinHostPort (h , p ))
194+ }
195+ if err := creds .CheckAuth (tt .args .ctx , tt .args .registry + "/someorg/someimage:sometag" , c , transport ); (err != nil ) != tt .wantErr {
174196 t .Errorf ("CheckAuth() error = %v, wantErr %v" , err , tt .wantErr )
175197 }
176198 })
@@ -179,141 +201,133 @@ func TestCheckAuth(t *testing.T) {
179201
180202func TestCheckAuthEmptyCreds (t * testing.T ) {
181203
182- localhost , _ := startServer (t , "" , "" )
204+ localhost , _ , _ := startServer (t , "" , "" )
183205 err := creds .CheckAuth (context .Background (), localhost + "/someorg/someimage:sometag" , docker.Credentials {}, http .DefaultTransport )
184206 if err != nil {
185207 t .Error (err )
186208 }
187209}
188210
189- func startServer (t * testing.T , uname , pwd string ) (addr , addrTLS string ) {
190- // TODO: this should be refactored to use OS-chosen ports so as not to
191- // fail when a user is running a function on the default port.)
192- listener , err := net .Listen ("tcp" , "localhost:0" )
193- if err != nil {
194- t .Fatal (err )
195- }
196- addr = listener .Addr ().String ()
197-
198- listenerTLS , err := net .Listen ("tcp" , "localhost:0" )
199- if err != nil {
200- t .Fatal (err )
201- }
202- addrTLS = listenerTLS .Addr ().String ()
203-
204- handler := http .HandlerFunc (func (resp http.ResponseWriter , req * http.Request ) {
205- if uname == "" || pwd == "" {
206- if req .Method == http .MethodPost {
207- resp .WriteHeader (http .StatusCreated )
208- } else {
209- resp .WriteHeader (http .StatusOK )
210- }
211- return
212- }
213- // TODO add also test for token based auth
214- resp .Header ().Add ("WWW-Authenticate" , "basic" )
215- if u , p , ok := req .BasicAuth (); ok {
216- if u == uname && p == pwd {
217- if req .Method == http .MethodPost {
218- resp .WriteHeader (http .StatusCreated )
219- } else {
220- resp .WriteHeader (http .StatusOK )
221- }
222- return
223- }
224- }
225- resp .WriteHeader (http .StatusUnauthorized )
226- })
227-
211+ // generate Certificates
212+ func generateCert (t * testing.T ) (tls.Certificate , * x509.Certificate ) {
228213 var randReader io.Reader = rand .Reader
229214
230215 caPublicKey , caPrivateKey , err := ed25519 .GenerateKey (randReader )
231216 if err != nil {
232217 t .Fatal (err )
233218 }
234219
235- ca := & x509.Certificate {
236- SerialNumber : big .NewInt (1 ),
237- Subject : pkix.Name {
238- CommonName : "localhost" ,
239- },
220+ caTemplate := & x509.Certificate {
221+ SerialNumber : big .NewInt (1 ),
222+ Subject : pkix.Name {CommonName : "localhost" },
240223 IPAddresses : []net.IP {net .IPv4 (127 , 0 , 0 , 1 ), net .IPv6loopback },
241224 DNSNames : []string {"localhost" , "test.io" },
242225 NotBefore : time .Now (),
243- NotAfter : time .Now ().AddDate (10 , 0 , 0 ),
226+ NotAfter : time .Now ().AddDate (1 , 0 , 0 ),
244227 IsCA : true ,
245228 ExtKeyUsage : []x509.ExtKeyUsage {x509 .ExtKeyUsageClientAuth , x509 .ExtKeyUsageServerAuth },
246229 ExtraExtensions : []pkix.Extension {},
247230 KeyUsage : x509 .KeyUsageDigitalSignature | x509 .KeyUsageCertSign ,
248231 BasicConstraintsValid : true ,
249232 }
250233
251- caBytes , err := x509 .CreateCertificate (randReader , ca , ca , caPublicKey , caPrivateKey )
234+ caBytes , err := x509 .CreateCertificate (randReader , caTemplate , caTemplate , caPublicKey , caPrivateKey )
252235 if err != nil {
253236 t .Fatal (err )
254237 }
255238
256- ca , err = x509 .ParseCertificate (caBytes )
239+ ca , err : = x509 .ParseCertificate (caBytes )
257240 if err != nil {
258241 t .Fatal (err )
259242 }
260243
261- cert := tls.Certificate {
244+ tls := tls.Certificate {
262245 Certificate : [][]byte {caBytes },
263246 PrivateKey : caPrivateKey ,
264247 Leaf : ca ,
265248 }
249+ return tls , ca
250+ }
266251
252+ func startServer (t * testing.T , uname , pwd string ) (addr , addrTLS string , ca * x509.Certificate ) {
253+ // create a custom handler function
254+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
255+ // no authentication required, empty creds
256+ if uname == "" || pwd == "" {
257+ if r .Method == http .MethodPost {
258+ w .WriteHeader (http .StatusCreated )
259+ } else {
260+ w .WriteHeader (http .StatusOK )
261+ }
262+ return
263+ }
264+
265+ w .Header ().Add ("WWW-Authenticate" , "basic" )
266+ if u , p , ok := r .BasicAuth (); ok {
267+ if u == uname && p == pwd {
268+ if r .Method == http .MethodPost {
269+ w .WriteHeader (http .StatusCreated )
270+ } else {
271+ w .WriteHeader (http .StatusOK )
272+ }
273+ return
274+ }
275+ }
276+ w .WriteHeader (http .StatusUnauthorized )
277+ })
278+
279+ // Setup certificates
280+ // tls Cert for the TLS server (has ca as Leaf)
281+ // x509 certificate which is its own CA for client
282+ tlsCert , ca := generateCert (t )
283+
284+ // create Server config
267285 server := http.Server {
268286 Handler : handler ,
269287 TLSConfig : & tls.Config {
270- ServerName : "localhost" ,
271- Certificates : []tls.Certificate {cert },
288+ ServerName : "localhost" ,
289+ // with the TLS certificate
290+ Certificates : []tls.Certificate {tlsCert },
272291 },
273292 }
274293
294+ // non-TLS listener
295+ listener , err := net .Listen ("tcp" , "localhost:0" )
296+ if err != nil {
297+ t .Fatal (err )
298+ }
299+
300+ // TLS listener
301+ listenerTLS , err := net .Listen ("tcp" , "localhost:0" )
302+ if err != nil {
303+ t .Fatal (err )
304+ }
305+ addr = listener .Addr ().String ()
306+ addrTLS = listenerTLS .Addr ().String ()
307+
308+ // listen for requests
275309 go func () {
276310 err := server .ServeTLS (listenerTLS , "" , "" )
277- if err != nil && ! strings . Contains ( err . Error (), "Server closed" ) {
311+ if err != nil && err != http . ErrServerClosed {
278312 panic (err )
279313 }
280314 }()
281315
282316 go func () {
283317 err := server .Serve (listener )
284- if err != nil && ! strings . Contains ( err . Error (), "Server closed" ) {
318+ if err != nil && err != http . ErrServerClosed {
285319 panic (err )
286320 }
287321 }()
288- // make the testing CA trusted by default HTTP transport/client
289- oldDefaultTransport := http .DefaultTransport
290- newDefaultTransport := http .DefaultTransport .(* http.Transport ).Clone ()
291- http .DefaultTransport = newDefaultTransport
292- caPool := x509 .NewCertPool ()
293- caPool .AddCert (ca )
294- newDefaultTransport .TLSClientConfig .RootCAs = caPool
295- dc := newDefaultTransport .DialContext
296- newDefaultTransport .DialContext = func (ctx context.Context , network , addr string ) (net.Conn , error ) {
297- h , p , err := net .SplitHostPort (addr )
298- if err != nil {
299- return nil , err
300- }
301- if h == "test.io" {
302- h = "localhost"
303- }
304- addr = net .JoinHostPort (h , p )
305- return dc (ctx , network , addr )
306- }
307-
322+ // shutdown servers at cleanup
308323 t .Cleanup (func () {
309324 err := server .Shutdown (context .Background ())
310325 if err != nil {
311326 t .Fatal (err )
312327 }
313- http .DefaultTransport = oldDefaultTransport
314328 })
315329
316- return addr , addrTLS
330+ return
317331}
318332
319333const (
0 commit comments