11package speedtest
22
33import (
4+ "context"
45 "encoding/json"
56 "errors"
67 "io/ioutil"
78 "net"
89 "net/http"
10+ "strings"
911 "sync"
1012 "time"
1113
@@ -124,28 +126,70 @@ func SpeedTest(c *cli.Context) error {
124126 // HTTP requests timeout
125127 http .DefaultClient .Timeout = time .Duration (c .Int (defs .OptionTimeout )) * time .Second
126128
127- // bind to source IP address if given
128- if src := c .String (defs .OptionSource ); src != "" {
129- // first we parse the IP to see if it's valid
130- localAddr , err := net .ResolveIPAddr ("ip" , src )
131- if err != nil {
132- log .Errorf ("Error parsing source IP: %s" , err )
133- return err
129+ forceIPv4 := c .Bool (defs .OptionIPv4 )
130+ forceIPv6 := c .Bool (defs .OptionIPv6 )
131+
132+ var network string
133+ switch {
134+ case forceIPv4 :
135+ network = "ip4"
136+ case forceIPv6 :
137+ network = "ip6"
138+ default :
139+ network = "ip"
140+ }
141+
142+ // bind to source IP address if given, or if ipv4/ipv6 is forced
143+ if src := c .String (defs .OptionSource ); src != "" || (forceIPv4 || forceIPv6 ) {
144+ var localTCPAddr * net.TCPAddr
145+ if src != "" {
146+ // first we parse the IP to see if it's valid
147+ addr , err := net .ResolveIPAddr (network , src )
148+ if err != nil {
149+ if strings .Contains (err .Error (), "no suitable address" ) {
150+ if forceIPv6 {
151+ log .Errorf ("Address %s is not a valid IPv6 address" , src )
152+ } else {
153+ log .Errorf ("Address %s is not a valid IPv4 address" , src )
154+ }
155+ } else {
156+ log .Errorf ("Error parsing source IP: %s" , err )
157+ }
158+ return err
159+ }
160+
161+ log .Debugf ("Using %s as source IP" , src )
162+ localTCPAddr = & net.TCPAddr {IP : addr .IP }
134163 }
135164
136- localTCPAddr := net.TCPAddr {IP : localAddr .IP }
165+ var dialContext func (context.Context , string , string ) (net.Conn , error )
166+ defaultDialer := & net.Dialer {
167+ Timeout : 30 * time .Second ,
168+ KeepAlive : 30 * time .Second ,
169+ }
170+
171+ if localTCPAddr != nil {
172+ defaultDialer .LocalAddr = localTCPAddr
173+ }
174+
175+ switch {
176+ case forceIPv4 :
177+ dialContext = func (ctx context.Context , network , address string ) (conn net.Conn , err error ) {
178+ return defaultDialer .DialContext (ctx , "tcp4" , address )
179+ }
180+ case forceIPv6 :
181+ dialContext = func (ctx context.Context , network , address string ) (conn net.Conn , err error ) {
182+ return defaultDialer .DialContext (ctx , "tcp6" , address )
183+ }
184+ default :
185+ dialContext = defaultDialer .DialContext
186+ }
137187
138188 // set default HTTP client's Transport to the one that binds the source address
139189 // this is modified from http.DefaultTransport
140190 transport := & http.Transport {
141- Proxy : http .ProxyFromEnvironment ,
142- DialContext : (& net.Dialer {
143- LocalAddr : & localTCPAddr ,
144- Timeout : 30 * time .Second ,
145- KeepAlive : 30 * time .Second ,
146- // although this option is marked deprecated, but it's still used in http.DefaultTransport, keeping as-is
147- DualStack : true ,
148- }).DialContext ,
191+ Proxy : http .ProxyFromEnvironment ,
192+ DialContext : dialContext ,
149193 ForceAttemptHTTP2 : true ,
150194 MaxIdleConns : 100 ,
151195 IdleConnTimeout : 90 * time .Second ,
@@ -188,7 +232,7 @@ func SpeedTest(c *cli.Context) error {
188232
189233 // if --server is given, do speed tests with all of them
190234 if len (c .IntSlice (defs .OptionServer )) > 0 {
191- return doSpeedTest (c , servers , telemetryServer , silent )
235+ return doSpeedTest (c , servers , telemetryServer , network , silent )
192236 } else {
193237 // else select the fastest server from the list
194238 log .Info ("Selecting the fastest server based on ping" )
@@ -202,7 +246,7 @@ func SpeedTest(c *cli.Context) error {
202246
203247 // spawn 10 concurrent pingers
204248 for i := 0 ; i < 10 ; i ++ {
205- go pingWorker (jobs , results , & wg , c .String (defs .OptionSource ))
249+ go pingWorker (jobs , results , & wg , c .String (defs .OptionSource ), network )
206250 }
207251
208252 // send ping jobs to workers
@@ -239,11 +283,11 @@ func SpeedTest(c *cli.Context) error {
239283 }
240284
241285 // do speed test on the server
242- return doSpeedTest (c , []defs.Server {servers [serverIdx ]}, telemetryServer , silent )
286+ return doSpeedTest (c , []defs.Server {servers [serverIdx ]}, telemetryServer , network , silent )
243287 }
244288}
245289
246- func pingWorker (jobs <- chan PingJob , results chan <- PingResult , wg * sync.WaitGroup , srcIp string ) {
290+ func pingWorker (jobs <- chan PingJob , results chan <- PingResult , wg * sync.WaitGroup , srcIp , network string ) {
247291 for {
248292 job := <- jobs
249293 server := job .Server
@@ -258,7 +302,7 @@ func pingWorker(jobs <-chan PingJob, results chan<- PingResult, wg *sync.WaitGro
258302 // check the server is up by accessing the ping URL and checking its returned value == empty and status code == 200
259303 if server .IsUp () {
260304 // if server is up, get ping
261- ping , _ , err := server .ICMPPingAndJitter (1 , srcIp )
305+ ping , _ , err := server .ICMPPingAndJitter (1 , srcIp , network )
262306 if err != nil {
263307 log .Debugf ("Can't ping server %s (%s), skipping" , server .Name , u .Hostname ())
264308 wg .Done ()
0 commit comments