@@ -7,94 +7,226 @@ package client
77
88import (
99 "context"
10+ "crypto/tls"
1011 "fmt"
1112 "io"
1213 "net"
14+ "sync"
1315 "time"
1416
17+ "github.com/quic-go/quic-go"
18+ "go.osspkg.com/algorithms/control"
1519 "go.osspkg.com/errors"
1620 "go.osspkg.com/ioutils"
1721 "go.osspkg.com/network/internal"
1822)
1923
20- type (
21- Client struct {
22- Address string
23- Timeout time.Duration
24- Network string
24+ type Client struct {
25+ Network string
26+ Address string
27+ Certificate * Certificate
28+
29+ Timeout time.Duration
30+ KeepAlive time.Duration
31+
32+ MaxIdleConns int
33+ BufferSize int
34+
35+ err error
36+ config * tls.Config
37+ pool * chanPool [* connect ]
38+ sem control.Semaphore
39+ once sync.Once
40+ }
41+
42+ func (v * Client ) setup () error {
43+ v .once .Do (func () {
44+ if v .err = internal .IsPassableNetwork (v .Network ); v .err != nil {
45+ return
46+ }
47+
48+ var (
49+ err error
50+ addr fmt.Stringer
51+ )
52+
53+ if err = v .applyTLSCertificate (); err != nil {
54+ v .err = err
55+ return
56+ }
57+
58+ switch v .Network {
59+ case internal .NetTCP :
60+ addr , err = net .ResolveTCPAddr (v .Network , v .Address )
61+ case internal .NetUDP :
62+ addr , err = net .ResolveUDPAddr (v .Network , v .Address )
63+ case internal .NetQUIC :
64+ if len (v .config .NextProtos ) == 0 {
65+ v .config .NextProtos = append (v .config .NextProtos , "quic" )
66+ }
67+ addr , err = net .ResolveUDPAddr (internal .NetUDP , v .Address )
68+ case internal .NetUNIX :
69+ addr , err = net .ResolveUnixAddr (v .Network , v .Address )
70+ default :
71+ addr , err = nil , fmt .Errorf ("invalid network name, use: tcp, udp, unix, quic" )
72+ }
73+ if err != nil {
74+ v .err = err
75+ return
76+ }
77+
78+ v .BufferSize = internal .NotZero [int ](v .BufferSize , 65535 )
79+ v .Timeout = internal .NotZeroDuration (v .Timeout , 1 * time .Second )
80+ v .KeepAlive = internal .NotZeroDuration (v .KeepAlive , 15 * time .Second )
81+ v .MaxIdleConns = internal .NotZero [int ](v .MaxIdleConns , 1 )
82+
83+ v .Address = addr .String ()
84+ v .sem = control .NewSemaphore (uint64 (v .MaxIdleConns ))
85+
86+ v .pool = newChanPool [* connect ](v .MaxIdleConns , func () * connect {
87+ pIdleAt := time .Now ().Add (v .KeepAlive )
88+ pconn , pclose , perr := v .dialConnect (context .Background ())
89+ return & connect {
90+ Conn : pconn ,
91+ CloseFunc : pclose ,
92+ Err : perr ,
93+ IdleAt : pIdleAt ,
94+ }
95+ })
96+ })
97+
98+ if v .err != nil {
99+ return v .err
25100 }
26- )
101+ return nil
102+ }
27103
28- func (v * Client ) dialConnect (ctx context.Context ) (net.Conn , error ) {
29- var (
30- addr fmt.Stringer
31- err error
32- )
33- if err = internal .IsPassableNetwork (v .Network ); err != nil {
34- return nil , err
104+ func (v * Client ) applyTLSCertificate () error {
105+ if v .Certificate == nil {
106+ return nil
35107 }
36- switch v .Network {
37- case "tcp" :
38- addr , err = net .ResolveTCPAddr ("tcp" , v .Address )
39- case "udp" :
40- addr , err = net .ResolveUDPAddr ("udp" , v .Address )
41- case "unix" :
42- addr , err = net .ResolveUnixAddr ("udp" , v .Address )
43- default :
44- return nil , fmt .Errorf ("invalid network name, use: tcp, udp, unix" )
108+ if v .config == nil {
109+ v .config = internal .DefaultTLSConfig ()
45110 }
111+ cert , ca , err := parseCertificate (* v .Certificate )
46112 if err != nil {
47- return nil , fmt .Errorf ("invalid address" )
113+ return err
114+ }
115+ if ca != nil {
116+ v .config .RootCAs = ca
117+ }
118+ if len (cert .Certificate ) >= 0 {
119+ v .config .Certificates = append (v .config .Certificates , cert )
120+ }
121+ v .config .InsecureSkipVerify = v .Certificate .InsecureSkipVerify
122+
123+ return nil
124+ }
125+
126+ func (v * Client ) dialConnect (ctx context.Context ) (action , func (), error ) {
127+ if v .Network == internal .NetQUIC {
128+ conn , err := quic .DialAddr (ctx , v .Address , v .config , & quic.Config {EnableDatagrams : true })
129+ if err != nil {
130+ return nil , nil , fmt .Errorf ("create connect: %w" , err )
131+ }
132+
133+ stream , err := conn .OpenStream ()
134+ if err != nil {
135+ writeLog (conn .CloseWithError (0 , "" ), "close connect" , v .Network , v .Address )
136+ return nil , nil , fmt .Errorf ("open stream: %w" , err )
137+ }
138+
139+ return stream , func () {
140+ writeLog (stream .Close (), "close stream" , v .Network , v .Address )
141+ writeLog (conn .CloseWithError (0 , "" ), "close connect" , v .Network , v .Address )
142+ }, nil
143+ }
144+
145+ if v .Certificate != nil {
146+ dial := & tls.Dialer {
147+ NetDialer : new (net.Dialer ),
148+ Config : v .config ,
149+ }
150+ conn , err := dial .DialContext (ctx , v .Network , v .Address )
151+ if err != nil {
152+ return nil , nil , fmt .Errorf ("create connect: %w" , err )
153+ }
154+ return conn , func () {
155+ writeLog (conn .Close (), "close connect" , v .Network , v .Address )
156+ }, nil
48157 }
49158
50159 var dial net.Dialer
51- conn , err := dial .DialContext (ctx , v .Network , addr . String () )
160+ conn , err := dial .DialContext (ctx , v .Network , v . Address )
52161 if err != nil {
53- return nil , fmt .Errorf ("create connect: %w" , err )
162+ return nil , nil , fmt .Errorf ("create connect: %w" , err )
54163 }
55- return conn , nil
164+ return conn , func () {
165+ writeLog (conn .Close (), "close connect" , v .Network , v .Address )
166+ }, nil
56167}
57168
58- func (v * Client ) Do (ctx context.Context , in io.Reader , out io.Writer ) error {
59- conn , err := v .dialConnect (ctx )
60- if err != nil {
61- return err
169+ func (v * Client ) Do (in io.Reader , out io.Writer ) (err error ) {
170+ if err = v .setup (); err != nil {
171+ return
62172 }
63173
64- defer func () {
65- conn .Close () // nolint: errcheck
66- }()
174+ v .sem .Acquire ()
175+ defer func () { v .sem .Release () }()
67176
68- ttl := internal . NotZeroDuration ( v . Timeout , 2 * time . Second )
177+ conn := v . pool . GetIdleOrCreateConn ( )
69178
70- t := time .Now ().Add (ttl )
71- err = errors .Wrap (conn .SetDeadline (t ), conn .SetReadDeadline (t ), conn .SetWriteDeadline (t ))
72- if err != nil {
73- return err
179+ if err = conn .GetError (); err != nil {
180+ return
74181 }
75182
76- ctx , cancel := context .WithTimeout (ctx , ttl )
77- defer cancel ()
183+ defer func () {
184+ conn .Err = errors .Wrap (conn .Err , err )
185+ v .pool .PutOrCloseIdleConn (conn )
186+ }()
78187
79- if _ , err = ioutils .Copy (conn , in ); err != nil {
80- return fmt .Errorf ("send message: %w" , err )
81- }
188+ errC := make (chan error , 1 )
189+ startC := make (chan struct {})
82190
83- rdC := make (chan interface {}, 1 )
84191 go func () {
85- defer close (rdC )
86- if _ , e := ioutils .Copy (out , conn ); e != nil {
87- rdC <- fmt .Errorf ("read message: %w" , e )
192+ close (startC )
193+
194+ if e := internal .Deadline (conn .Conn , v .Timeout * 2 ); e != nil {
195+ errC <- fmt .Errorf ("update deadline: %w" , e )
196+ return
88197 }
89- }()
90198
91- select {
92- case <- ctx .Done ():
93- return fmt .Errorf ("closed by timeout" )
94- case rcV := <- rdC :
95- if e , ok := rcV .(error ); ok {
96- return e
199+ n , e := ioutils .CopyPack (out , conn .Conn , v .BufferSize )
200+ if e != nil {
201+ errC <- fmt .Errorf ("read message: %w" , e )
202+ return
203+ } else if n == 0 {
204+ errC <- fmt .Errorf ("read message: got 0 bytes" )
205+ return
97206 }
98- return nil
207+
208+ errC <- nil
209+ }()
210+
211+ <- startC
212+
213+ n , e := ioutils .CopyPack (conn .Conn , in , v .BufferSize )
214+ if e != nil {
215+ err = fmt .Errorf ("write message: %w" , e )
216+ return
217+ } else if n == 0 {
218+ err = fmt .Errorf ("write message: set 0 bytes" )
219+ return
220+ }
221+
222+ if err = <- errC ; err != nil {
223+ return
99224 }
225+
226+ if err = conn .Conn .SetWriteDeadline (time .Now ().Add (v .KeepAlive )); err != nil {
227+ err = fmt .Errorf ("update deadline: %w" , err )
228+ return
229+ }
230+
231+ return
100232}
0 commit comments