Skip to content

Commit c17c80c

Browse files
authored
Merge pull request #385 from dottyjones/master
Add test for handshake deadline
2 parents 21ab95f + badcf87 commit c17c80c

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

client_server_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,65 @@ func TestDialTimeout(t *testing.T) {
344344
}
345345
}
346346

347+
// requireDeadlineNetConn fails the current test when Read or Write are called
348+
// with no deadline.
349+
type requireDeadlineNetConn struct {
350+
t *testing.T
351+
c net.Conn
352+
readDeadlineIsSet bool
353+
writeDeadlineIsSet bool
354+
}
355+
356+
func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
357+
c.writeDeadlineIsSet = !t.Equal(time.Time{})
358+
c.readDeadlineIsSet = c.writeDeadlineIsSet
359+
return c.c.SetDeadline(t)
360+
}
361+
362+
func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
363+
c.readDeadlineIsSet = !t.Equal(time.Time{})
364+
return c.c.SetDeadline(t)
365+
}
366+
367+
func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
368+
c.writeDeadlineIsSet = !t.Equal(time.Time{})
369+
return c.c.SetDeadline(t)
370+
}
371+
372+
func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
373+
if !c.writeDeadlineIsSet {
374+
c.t.Fatalf("write with no deadline")
375+
}
376+
return c.c.Write(p)
377+
}
378+
379+
func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
380+
if !c.readDeadlineIsSet {
381+
c.t.Fatalf("read with no deadline")
382+
}
383+
return c.c.Read(p)
384+
}
385+
386+
func (c *requireDeadlineNetConn) Close() error { return c.c.Close() }
387+
func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() }
388+
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
389+
390+
func TestHandshakeTimeout(t *testing.T) {
391+
s := newServer(t)
392+
defer s.Close()
393+
394+
d := cstDialer
395+
d.NetDial = func(n, a string) (net.Conn, error) {
396+
c, err := net.Dial(n, a)
397+
return &requireDeadlineNetConn{c: c, t: t}, err
398+
}
399+
ws, _, err := d.Dial(s.URL, nil)
400+
if err != nil {
401+
t.Fatal("Dial:", err)
402+
}
403+
ws.Close()
404+
}
405+
347406
func TestDialBadScheme(t *testing.T) {
348407
s := newServer(t)
349408
defer s.Close()

0 commit comments

Comments
 (0)