Skip to content

Commit 25f4265

Browse files
committed
fix flaky unit test
picking knative#2749 Signed-off-by: David Fridrich <[email protected]> Signed-off-by: Matej Vašek <[email protected]>
1 parent a2c72f4 commit 25f4265

File tree

1 file changed

+92
-78
lines changed

1 file changed

+92
-78
lines changed

pkg/docker/creds/credentials_test.go

Lines changed: 92 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func TestCheckAuth(t *testing.T) {
9292
incorrectPwd = "badpwd"
9393
)
9494

95-
localhost, localhostTLS := startServer(t, uname, pwd)
95+
localhost, localhostTLS, cert := startServer(t, uname, pwd)
9696

9797
_, portTLS, err := net.SplitHostPort(localhostTLS)
9898
if err != nil {
@@ -132,7 +132,6 @@ func TestCheckAuth(t *testing.T) {
132132
},
133133
wantErr: false,
134134
},
135-
136135
{
137136
name: "correct credentials non-localhost",
138137
args: args{
@@ -170,7 +169,30 @@ func TestCheckAuth(t *testing.T) {
170169
Username: tt.args.username,
171170
Password: tt.args.password,
172171
}
173-
if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, http.DefaultTransport); (err != nil) != tt.wantErr {
172+
// create trusted certificates pool and add our certificate
173+
certPool := x509.NewCertPool()
174+
certPool.AddCert(cert)
175+
176+
// client transport with the certificate
177+
transport := &http.Transport{
178+
TLSClientConfig: &tls.Config{
179+
RootCAs: certPool,
180+
},
181+
}
182+
183+
dialer := &net.Dialer{}
184+
185+
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
186+
h, p, err := net.SplitHostPort(addr)
187+
if err != nil {
188+
return nil, err
189+
}
190+
if h == "test.io" {
191+
h = "localhost"
192+
}
193+
return dialer.DialContext(ctx, network, net.JoinHostPort(h, p))
194+
}
195+
if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, transport); (err != nil) != tt.wantErr {
174196
t.Errorf("CheckAuth() error = %v, wantErr %v", err, tt.wantErr)
175197
}
176198
})
@@ -179,141 +201,133 @@ func TestCheckAuth(t *testing.T) {
179201

180202
func TestCheckAuthEmptyCreds(t *testing.T) {
181203

182-
localhost, _ := startServer(t, "", "")
204+
localhost, _, _ := startServer(t, "", "")
183205
err := creds.CheckAuth(context.Background(), localhost+"/someorg/someimage:sometag", docker.Credentials{}, http.DefaultTransport)
184206
if err != nil {
185207
t.Error(err)
186208
}
187209
}
188210

189-
func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) {
190-
// TODO: this should be refactored to use OS-chosen ports so as not to
191-
// fail when a user is running a function on the default port.)
192-
listener, err := net.Listen("tcp", "localhost:0")
193-
if err != nil {
194-
t.Fatal(err)
195-
}
196-
addr = listener.Addr().String()
197-
198-
listenerTLS, err := net.Listen("tcp", "localhost:0")
199-
if err != nil {
200-
t.Fatal(err)
201-
}
202-
addrTLS = listenerTLS.Addr().String()
203-
204-
handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
205-
if uname == "" || pwd == "" {
206-
if req.Method == http.MethodPost {
207-
resp.WriteHeader(http.StatusCreated)
208-
} else {
209-
resp.WriteHeader(http.StatusOK)
210-
}
211-
return
212-
}
213-
// TODO add also test for token based auth
214-
resp.Header().Add("WWW-Authenticate", "basic")
215-
if u, p, ok := req.BasicAuth(); ok {
216-
if u == uname && p == pwd {
217-
if req.Method == http.MethodPost {
218-
resp.WriteHeader(http.StatusCreated)
219-
} else {
220-
resp.WriteHeader(http.StatusOK)
221-
}
222-
return
223-
}
224-
}
225-
resp.WriteHeader(http.StatusUnauthorized)
226-
})
227-
211+
// generate Certificates
212+
func generateCert(t *testing.T) (tls.Certificate, *x509.Certificate) {
228213
var randReader io.Reader = rand.Reader
229214

230215
caPublicKey, caPrivateKey, err := ed25519.GenerateKey(randReader)
231216
if err != nil {
232217
t.Fatal(err)
233218
}
234219

235-
ca := &x509.Certificate{
236-
SerialNumber: big.NewInt(1),
237-
Subject: pkix.Name{
238-
CommonName: "localhost",
239-
},
220+
caTemplate := &x509.Certificate{
221+
SerialNumber: big.NewInt(1),
222+
Subject: pkix.Name{CommonName: "localhost"},
240223
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
241224
DNSNames: []string{"localhost", "test.io"},
242225
NotBefore: time.Now(),
243-
NotAfter: time.Now().AddDate(10, 0, 0),
226+
NotAfter: time.Now().AddDate(1, 0, 0),
244227
IsCA: true,
245228
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
246229
ExtraExtensions: []pkix.Extension{},
247230
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
248231
BasicConstraintsValid: true,
249232
}
250233

251-
caBytes, err := x509.CreateCertificate(randReader, ca, ca, caPublicKey, caPrivateKey)
234+
caBytes, err := x509.CreateCertificate(randReader, caTemplate, caTemplate, caPublicKey, caPrivateKey)
252235
if err != nil {
253236
t.Fatal(err)
254237
}
255238

256-
ca, err = x509.ParseCertificate(caBytes)
239+
ca, err := x509.ParseCertificate(caBytes)
257240
if err != nil {
258241
t.Fatal(err)
259242
}
260243

261-
cert := tls.Certificate{
244+
tls := tls.Certificate{
262245
Certificate: [][]byte{caBytes},
263246
PrivateKey: caPrivateKey,
264247
Leaf: ca,
265248
}
249+
return tls, ca
250+
}
266251

252+
func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string, ca *x509.Certificate) {
253+
// create a custom handler function
254+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
255+
// no authentication required, empty creds
256+
if uname == "" || pwd == "" {
257+
if r.Method == http.MethodPost {
258+
w.WriteHeader(http.StatusCreated)
259+
} else {
260+
w.WriteHeader(http.StatusOK)
261+
}
262+
return
263+
}
264+
265+
w.Header().Add("WWW-Authenticate", "basic")
266+
if u, p, ok := r.BasicAuth(); ok {
267+
if u == uname && p == pwd {
268+
if r.Method == http.MethodPost {
269+
w.WriteHeader(http.StatusCreated)
270+
} else {
271+
w.WriteHeader(http.StatusOK)
272+
}
273+
return
274+
}
275+
}
276+
w.WriteHeader(http.StatusUnauthorized)
277+
})
278+
279+
// Setup certificates
280+
// tls Cert for the TLS server (has ca as Leaf)
281+
// x509 certificate which is its own CA for client
282+
tlsCert, ca := generateCert(t)
283+
284+
// create Server config
267285
server := http.Server{
268286
Handler: handler,
269287
TLSConfig: &tls.Config{
270-
ServerName: "localhost",
271-
Certificates: []tls.Certificate{cert},
288+
ServerName: "localhost",
289+
// with the TLS certificate
290+
Certificates: []tls.Certificate{tlsCert},
272291
},
273292
}
274293

294+
// non-TLS listener
295+
listener, err := net.Listen("tcp", "localhost:0")
296+
if err != nil {
297+
t.Fatal(err)
298+
}
299+
300+
// TLS listener
301+
listenerTLS, err := net.Listen("tcp", "localhost:0")
302+
if err != nil {
303+
t.Fatal(err)
304+
}
305+
addr = listener.Addr().String()
306+
addrTLS = listenerTLS.Addr().String()
307+
308+
// listen for requests
275309
go func() {
276310
err := server.ServeTLS(listenerTLS, "", "")
277-
if err != nil && !strings.Contains(err.Error(), "Server closed") {
311+
if err != nil && err != http.ErrServerClosed {
278312
panic(err)
279313
}
280314
}()
281315

282316
go func() {
283317
err := server.Serve(listener)
284-
if err != nil && !strings.Contains(err.Error(), "Server closed") {
318+
if err != nil && err != http.ErrServerClosed {
285319
panic(err)
286320
}
287321
}()
288-
// make the testing CA trusted by default HTTP transport/client
289-
oldDefaultTransport := http.DefaultTransport
290-
newDefaultTransport := http.DefaultTransport.(*http.Transport).Clone()
291-
http.DefaultTransport = newDefaultTransport
292-
caPool := x509.NewCertPool()
293-
caPool.AddCert(ca)
294-
newDefaultTransport.TLSClientConfig.RootCAs = caPool
295-
dc := newDefaultTransport.DialContext
296-
newDefaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
297-
h, p, err := net.SplitHostPort(addr)
298-
if err != nil {
299-
return nil, err
300-
}
301-
if h == "test.io" {
302-
h = "localhost"
303-
}
304-
addr = net.JoinHostPort(h, p)
305-
return dc(ctx, network, addr)
306-
}
307-
322+
// shutdown servers at cleanup
308323
t.Cleanup(func() {
309324
err := server.Shutdown(context.Background())
310325
if err != nil {
311326
t.Fatal(err)
312327
}
313-
http.DefaultTransport = oldDefaultTransport
314328
})
315329

316-
return addr, addrTLS
330+
return
317331
}
318332

319333
const (

0 commit comments

Comments
 (0)