Skip to content

Commit f90b62c

Browse files
author
dottyjones
authored
Add test for handshake deadline
1 parent 21ab95f commit f90b62c

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

client_server_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,63 @@ func TestDialTimeout(t *testing.T) {
344344
}
345345
}
346346

347+
// netConnDeadlineObserver fails test if read or write called without deadline.
348+
type netConnDeadlineObserver struct {
349+
t *testing.T
350+
c net.Conn
351+
read, write bool
352+
}
353+
354+
func (c *netConnDeadlineObserver) SetDeadline(t time.Time) error {
355+
c.write = !t.Equal(time.Time{})
356+
c.read = c.write
357+
return c.c.SetDeadline(t)
358+
}
359+
360+
func (c *netConnDeadlineObserver) SetReadDeadline(t time.Time) error {
361+
c.read = !t.Equal(time.Time{})
362+
return c.c.SetDeadline(t)
363+
}
364+
365+
func (c *netConnDeadlineObserver) SetWriteDeadline(t time.Time) error {
366+
c.write = !t.Equal(time.Time{})
367+
return c.c.SetDeadline(t)
368+
}
369+
370+
func (c *netConnDeadlineObserver) Write(p []byte) (int, error) {
371+
if !c.write {
372+
c.t.Fatalf("write with no deadline")
373+
}
374+
return c.c.Write(p)
375+
}
376+
377+
func (c *netConnDeadlineObserver) Read(p []byte) (int, error) {
378+
if !c.read {
379+
c.t.Fatalf("read with no deadline")
380+
}
381+
return c.c.Read(p)
382+
}
383+
384+
func (c *netConnDeadlineObserver) Close() error { return c.c.Close() }
385+
func (c *netConnDeadlineObserver) LocalAddr() net.Addr { return c.c.LocalAddr() }
386+
func (c *netConnDeadlineObserver) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
387+
388+
func TestHandshakeTimeout(t *testing.T) {
389+
s := newServer(t)
390+
defer s.Close()
391+
392+
d := cstDialer
393+
d.NetDial = func(n, a string) (net.Conn, error) {
394+
c, err := net.Dial(n, a)
395+
return &netConnDeadlineObserver{c: c, t: t}, err
396+
}
397+
ws, _, err := d.Dial(s.URL, nil)
398+
if err != nil {
399+
t.Fatal("Dial:", err)
400+
}
401+
ws.Close()
402+
}
403+
347404
func TestDialBadScheme(t *testing.T) {
348405
s := newServer(t)
349406
defer s.Close()

0 commit comments

Comments
 (0)