diff --git a/gorequest.go b/gorequest.go index 8edb0fc..eb06cfc 100644 --- a/gorequest.go +++ b/gorequest.go @@ -31,6 +31,7 @@ import ( "github.com/moul/http2curl" "golang.org/x/net/publicsuffix" + //"context" ) type Request *http.Request @@ -454,10 +455,14 @@ func (s *SuperAgent) queryStruct(content interface{}) *SuperAgent { } func (s *SuperAgent) queryString(content string) *SuperAgent { - var val map[string]string + var val map[string]interface{} if err := json.Unmarshal([]byte(content), &val); err == nil { for k, v := range val { - s.QueryData.Add(k, v) + if v == nil { + s.QueryData.Add(k, "null") + } else { + s.QueryData.Add(k, fmt.Sprintf("%v", v)) + } } } else { if queryData, err := url.ParseQuery(content); err == nil { @@ -487,6 +492,16 @@ func (s *SuperAgent) Param(key string, value string) *SuperAgent { } func (s *SuperAgent) Timeout(timeout time.Duration) *SuperAgent { + //s.Transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + // conn, err := net.DialTimeout(network, addr, timeout) + // if err != nil { + // s.Errors = append(s.Errors, err) + // return nil, err + // } + // conn.SetDeadline(time.Now().Add(timeout)) + // return conn, nil + //} + s.Transport.Dial = func(network, addr string) (net.Conn, error) { conn, err := net.DialTimeout(network, addr, timeout) if err != nil { @@ -603,17 +618,17 @@ func (s *SuperAgent) RedirectPolicy(policy func(req Request, via []Request) erro // func (s *SuperAgent) Send(content interface{}) *SuperAgent { // TODO: add normal text mode or other mode to Send func - switch v := reflect.ValueOf(content); v.Kind() { + v := reflect.ValueOf(content) + vKind := getKind(v) + switch vKind { case reflect.String: s.SendString(v.String()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: // includes rune + case reflect.Int: // includes rune s.SendString(strconv.FormatInt(v.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // includes byte + case reflect.Uint: // includes byte s.SendString(strconv.FormatUint(v.Uint(), 10)) - case reflect.Float64: - s.SendString(strconv.FormatFloat(v.Float(), 'f', -1, 64)) case reflect.Float32: - s.SendString(strconv.FormatFloat(v.Float(), 'f', -1, 32)) + s.SendString(strconv.FormatFloat(v.Float(), 'f', -1, 64)) case reflect.Bool: s.SendString(strconv.FormatBool(v.Bool())) case reflect.Struct: @@ -866,72 +881,41 @@ func (s *SuperAgent) SendFile(file interface{}, args ...string) *SuperAgent { return s } +func changeMapToURLValuesLoop(urlValues *url.Values, k string, data interface{}) { + dataVal := reflect.ValueOf(data) + dataKind := getKind(dataVal) + switch dataKind { + case reflect.Bool: + urlValues.Add(k, strconv.FormatBool(dataVal.Bool())) + case reflect.String: + urlValues.Add(k, dataVal.String()) + case reflect.Int: + urlValues.Add(k, strconv.FormatInt(dataVal.Int(), 10)) + case reflect.Uint: + urlValues.Add(k, strconv.FormatUint(dataVal.Uint(), 10)) + case reflect.Float32: + urlValues.Add(k, strconv.FormatFloat(dataVal.Float(), 'f', -1, 64)) + case reflect.Ptr: + changeMapToURLValuesLoop(urlValues, k, dataVal.Elem().Interface()) + case reflect.Slice: + for i := 0; i < dataVal.Len(); i++ { + changeMapToURLValuesLoop(urlValues, k, dataVal.Index(i).Interface()) + } + case reflect.Map: // exist this case? to check & test ??? TODO + for _, mk := range dataVal.MapKeys() { + changeMapToURLValuesLoop(urlValues, k, dataVal.MapIndex(mk).Interface()) + //changeMapToURLValuesLoop(urlValues, mk.String(), dataVal.MapIndex(mk).Interface()) + } + case reflect.Struct: // exist this case? to check & test ??? TODO consider how to use fieldName tag etc. + case reflect.Invalid: // TODO + default: // TODO + } +} + func changeMapToURLValues(data map[string]interface{}) url.Values { var newUrlValues = url.Values{} for k, v := range data { - switch val := v.(type) { - case string: - newUrlValues.Add(k, val) - case bool: - newUrlValues.Add(k, strconv.FormatBool(val)) - // if a number, change to string - // json.Number used to protect against a wrong (for GoRequest) default conversion - // which always converts number to float64. - // This type is caused by using Decoder.UseNumber() - case json.Number: - newUrlValues.Add(k, string(val)) - case int: - newUrlValues.Add(k, strconv.FormatInt(int64(val), 10)) - // TODO add all other int-Types (int8, int16, ...) - case float64: - newUrlValues.Add(k, strconv.FormatFloat(float64(val), 'f', -1, 64)) - case float32: - newUrlValues.Add(k, strconv.FormatFloat(float64(val), 'f', -1, 64)) - // following slices are mostly needed for tests - case []string: - for _, element := range val { - newUrlValues.Add(k, element) - } - case []int: - for _, element := range val { - newUrlValues.Add(k, strconv.FormatInt(int64(element), 10)) - } - case []bool: - for _, element := range val { - newUrlValues.Add(k, strconv.FormatBool(element)) - } - case []float64: - for _, element := range val { - newUrlValues.Add(k, strconv.FormatFloat(float64(element), 'f', -1, 64)) - } - case []float32: - for _, element := range val { - newUrlValues.Add(k, strconv.FormatFloat(float64(element), 'f', -1, 64)) - } - // these slices are used in practice like sending a struct - case []interface{}: - - if len(val) <= 0 { - continue - } - - switch val[0].(type) { - case string: - for _, element := range val { - newUrlValues.Add(k, element.(string)) - } - case bool: - for _, element := range val { - newUrlValues.Add(k, strconv.FormatBool(element.(bool))) - } - case json.Number: - for _, element := range val { - newUrlValues.Add(k, string(element.(json.Number))) - } - } - default: - // TODO add ptr, arrays, ... - } + changeMapToURLValuesLoop(&newUrlValues, k, v) } return newUrlValues } @@ -1310,3 +1294,18 @@ func (s *SuperAgent) AsCurlCommand() (string, error) { } return cmd.String(), nil } + +func getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float32 + default: + return kind + } +} diff --git a/gorequest_test.go b/gorequest_test.go index 728083e..545f001 100644 --- a/gorequest_test.go +++ b/gorequest_test.go @@ -1361,6 +1361,7 @@ func TestQueryFunc(t *testing.T) { const case2_send_struct = "/send_struct" const case3_send_string_with_duplicates = "/send_string_with_duplicates" const case4_send_map = "/send_map" + const case5_send_string_like_map_with_multi_type_vaule = "/string_like_map_with_multi_type_vaule" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != POST { t.Errorf("Expected method %q; got %q", POST, r.Method) @@ -1391,6 +1392,11 @@ func TestQueryFunc(t *testing.T) { checkQuery(t, v, "query2", "test2") checkQuery(t, v, "query3", "3.1415926") checkQuery(t, v, "query4", "true") + case case5_send_string_like_map_with_multi_type_vaule: + checkQuery(t, v, "query1", "test1") + checkQuery(t, v, "query2", "test2") + checkQuery(t, v, "query3", "3.1415926") + checkQuery(t, v, "query4", "true") } })) defer ts.Close() @@ -1427,6 +1433,10 @@ func TestQueryFunc(t *testing.T) { "query4": true, }). End() + + New().Post(ts.URL + case5_send_string_like_map_with_multi_type_vaule). + Query(`{"query1": "test1", "query2": "test2", "query3": "3.1415926", "query4":true}`). + End() } // TODO: more tests on redirect