diff --git a/api_test.go b/api_test.go index d0598a8..79d534a 100644 --- a/api_test.go +++ b/api_test.go @@ -2,6 +2,7 @@ package jmespath import ( "encoding/json" + "strings" "testing" "github.com/jmespath/go-jmespath/internal/testify/assert" @@ -18,6 +19,24 @@ func TestValidUncompiledExpressionSearches(t *testing.T) { assert.Equal(2.0, result) } +func TestJSONNumber(t *testing.T) { + assert := assert.New(t) + var d interface{} + dec := json.NewDecoder(strings.NewReader(`{"foo": [{"baz":0}, {"baz":1}, {"baz":-2}, {"baz":3}, {"baz":-4}]}`)) + dec.UseNumber() + err := dec.Decode(&d) + assert.Nil(err) + r, err := Search("sort_by(foo, &baz)", d) + assert.Nil(err) + result, ok := r.([]interface{}) + assert.True(ok) + assert.Equal("-4", result[0].(map[string]interface{})["baz"].(json.Number).String()) + assert.Equal("-2", result[1].(map[string]interface{})["baz"].(json.Number).String()) + assert.Equal("0", result[2].(map[string]interface{})["baz"].(json.Number).String()) + assert.Equal("1", result[3].(map[string]interface{})["baz"].(json.Number).String()) + assert.Equal("3", result[4].(map[string]interface{})["baz"].(json.Number).String()) +} + func TestValidPrecompiledExpressionSearches(t *testing.T) { assert := assert.New(t) data := make(map[string]interface{}) diff --git a/functions.go b/functions.go index e9770e8..3d1899a 100644 --- a/functions.go +++ b/functions.go @@ -17,14 +17,16 @@ type jpFunction func(arguments []interface{}) (interface{}, error) type jpType string const ( - jpNumber jpType = "number" - jpString jpType = "string" - jpArray jpType = "array" - jpObject jpType = "object" - jpArrayNumber jpType = "array[number]" - jpArrayString jpType = "array[string]" - jpExpref jpType = "expref" - jpAny jpType = "any" + jpNumber jpType = "number" + jpString jpType = "string" + jpArray jpType = "array" + jpObject jpType = "object" + jpJSONNumber jpType = "jsnumber" + jpArrayNumber jpType = "array[number]" + jpArrayString jpType = "array[string]" + jpArrayJSONNumber jpType = "array[jsnumber]" + jpExpref jpType = "expref" + jpAny jpType = "any" ) type functionEntry struct { @@ -117,6 +119,47 @@ func (a *byExprFloat) Less(i, j int) bool { return ith < jth } +type byExprJSONNumber struct { + intr *treeInterpreter + node ASTNode + items []interface{} + hasError bool +} + +func (a *byExprJSONNumber) Len() int { + return len(a.items) +} +func (a *byExprJSONNumber) Swap(i, j int) { + a.items[i], a.items[j] = a.items[j], a.items[i] +} +func (a *byExprJSONNumber) Less(i, j int) bool { + first, err := a.intr.Execute(a.node, a.items[i]) + if err != nil { + a.hasError = true + // Return a dummy value. + return true + } + ith, ok := first.(json.Number) + if !ok { + a.hasError = true + return true + } + second, err := a.intr.Execute(a.node, a.items[j]) + if err != nil { + a.hasError = true + // Return a dummy value. + return true + } + jth, ok := second.(json.Number) + if !ok { + a.hasError = true + return true + } + left, _ := ith.Float64() + righ, _ := jth.Float64() + return left < righ +} + type functionCaller struct { functionTable map[string]functionEntry } @@ -142,21 +185,21 @@ func newFunctionCaller() *functionCaller { "abs": { name: "abs", arguments: []argSpec{ - {types: []jpType{jpNumber}}, + {types: []jpType{jpNumber, jpJSONNumber}}, }, handler: jpfAbs, }, "avg": { name: "avg", arguments: []argSpec{ - {types: []jpType{jpArrayNumber}}, + {types: []jpType{jpArrayNumber, jpArrayJSONNumber}}, }, handler: jpfAvg, }, "ceil": { name: "ceil", arguments: []argSpec{ - {types: []jpType{jpNumber}}, + {types: []jpType{jpNumber, jpJSONNumber}}, }, handler: jpfCeil, }, @@ -179,7 +222,7 @@ func newFunctionCaller() *functionCaller { "floor": { name: "floor", arguments: []argSpec{ - {types: []jpType{jpNumber}}, + {types: []jpType{jpNumber, jpJSONNumber}}, }, handler: jpfFloor, }, @@ -195,7 +238,7 @@ func newFunctionCaller() *functionCaller { "max": { name: "max", arguments: []argSpec{ - {types: []jpType{jpArrayNumber, jpArrayString}}, + {types: []jpType{jpArrayNumber, jpArrayString, jpArrayJSONNumber}}, }, handler: jpfMax, }, @@ -218,14 +261,14 @@ func newFunctionCaller() *functionCaller { "sum": { name: "sum", arguments: []argSpec{ - {types: []jpType{jpArrayNumber}}, + {types: []jpType{jpArrayNumber, jpArrayJSONNumber}}, }, handler: jpfSum, }, "min": { name: "min", arguments: []argSpec{ - {types: []jpType{jpArrayNumber, jpArrayString}}, + {types: []jpType{jpArrayNumber, jpArrayString, jpArrayJSONNumber}}, }, handler: jpfMin, }, @@ -262,7 +305,7 @@ func newFunctionCaller() *functionCaller { "sort": { name: "sort", arguments: []argSpec{ - {types: []jpType{jpArrayString, jpArrayNumber}}, + {types: []jpType{jpArrayString, jpArrayNumber, jpArrayJSONNumber}}, }, handler: jpfSort, }, @@ -352,6 +395,10 @@ func (a *argSpec) typeCheck(arg interface{}) error { if _, ok := arg.(float64); ok { return nil } + case jpJSONNumber: + if _, ok := arg.(json.Number); ok { + return nil + } case jpString: if _, ok := arg.(string); ok { return nil @@ -368,6 +415,10 @@ func (a *argSpec) typeCheck(arg interface{}) error { if _, ok := toArrayNum(arg); ok { return nil } + case jpArrayJSONNumber: + if _, ok := toArrayJSONNum(arg); ok { + return nil + } case jpArrayString: if _, ok := toArrayStr(arg); ok { return nil @@ -401,8 +452,18 @@ func (f *functionCaller) CallFunction(name string, arguments []interface{}, intr } func jpfAbs(arguments []interface{}) (interface{}, error) { - num := arguments[0].(float64) - return math.Abs(num), nil + switch num := arguments[0].(type) { + case float64: + return math.Abs(num), nil + case json.Number: + v, err := num.Float64() + if err != nil { + return nil, err + } + return math.Abs(v), nil + default: + return nil, errors.New("unknown type") + } } func jpfLength(arguments []interface{}) (interface{}, error) { @@ -431,13 +492,34 @@ func jpfAvg(arguments []interface{}) (interface{}, error) { length := float64(len(args)) numerator := 0.0 for _, n := range args { - numerator += n.(float64) + if a, ok := n.(float64); ok { + numerator += a + continue + } + if js, ok := n.(json.Number); ok { + a, err := js.Float64() + if err != nil { + continue + } + numerator += a + continue + } } return numerator / length, nil } func jpfCeil(arguments []interface{}) (interface{}, error) { - val := arguments[0].(float64) - return math.Ceil(val), nil + switch num := arguments[0].(type) { + case float64: + return math.Ceil(num), nil + case json.Number: + v, err := num.Float64() + if err != nil { + return nil, err + } + return math.Ceil(v), nil + default: + return nil, errors.New("unknown type") + } } func jpfContains(arguments []interface{}) (interface{}, error) { search := arguments[0] @@ -463,8 +545,18 @@ func jpfEndsWith(arguments []interface{}) (interface{}, error) { return strings.HasSuffix(search, suffix), nil } func jpfFloor(arguments []interface{}) (interface{}, error) { - val := arguments[0].(float64) - return math.Floor(val), nil + switch num := arguments[0].(type) { + case float64: + return math.Floor(num), nil + case json.Number: + v, err := num.Float64() + if err != nil { + return nil, err + } + return math.Floor(v), nil + default: + return nil, errors.New("unknown type") + } } func jpfMap(arguments []interface{}) (interface{}, error) { intr := arguments[0].(*treeInterpreter) @@ -482,7 +574,14 @@ func jpfMap(arguments []interface{}) (interface{}, error) { return mapped, nil } func jpfMax(arguments []interface{}) (interface{}, error) { - if items, ok := toArrayNum(arguments[0]); ok { + for _, f := range []func(interface{}) ([]float64, bool){ + toArrayNum, + toArrayJSONNum, + } { + items, ok := f(arguments[0]) + if !ok { + continue + } if len(items) == 0 { return nil, nil } @@ -538,6 +637,27 @@ func jpfMaxBy(arguments []interface{}) (interface{}, error) { return nil, err } switch t := start.(type) { + case json.Number: + bestVal, err := t.Float64() + if err != nil { + return nil, err + } + bestItem := arr[0] + for _, item := range arr[1:] { + result, err := intr.Execute(node, item) + if err != nil { + return nil, err + } + current, ok := result.(float64) + if !ok { + return nil, errors.New("invalid type, must be number") + } + if current > bestVal { + bestVal = current + bestItem = item + } + } + return bestItem, nil case float64: bestVal := t bestItem := arr[0] @@ -575,20 +695,36 @@ func jpfMaxBy(arguments []interface{}) (interface{}, error) { } return bestItem, nil default: - return nil, errors.New("invalid type, must be number of string") + return nil, fmt.Errorf("invalid type, must be number or string: %T", t) } } func jpfSum(arguments []interface{}) (interface{}, error) { - items, _ := toArrayNum(arguments[0]) - sum := 0.0 - for _, item := range items { - sum += item + for _, f := range []func(interface{}) ([]float64, bool){ + toArrayNum, + toArrayJSONNum, + } { + items, ok := f(arguments[0]) + if !ok { + continue + } + sum := 0.0 + for _, item := range items { + sum += item + } + return sum, nil } - return sum, nil + return nil, errors.New("unable to sum") } func jpfMin(arguments []interface{}) (interface{}, error) { - if items, ok := toArrayNum(arguments[0]); ok { + for _, f := range []func(interface{}) ([]float64, bool){ + toArrayNum, + toArrayJSONNum, + } { + items, ok := f(arguments[0]) + if !ok { + continue + } if len(items) == 0 { return nil, nil } @@ -633,7 +769,29 @@ func jpfMinBy(arguments []interface{}) (interface{}, error) { if err != nil { return nil, err } - if t, ok := start.(float64); ok { + switch t := start.(type) { + case json.Number: + bestVal, err := t.Float64() + if err != nil { + return nil, err + } + bestItem := arr[0] + for _, item := range arr[1:] { + result, err := intr.Execute(node, item) + if err != nil { + return nil, err + } + current, ok := result.(float64) + if !ok { + return nil, errors.New("invalid type, must be number") + } + if current < bestVal { + bestVal = current + bestItem = item + } + } + return bestItem, nil + case float64: bestVal := t bestItem := arr[0] for _, item := range arr[1:] { @@ -651,7 +809,7 @@ func jpfMinBy(arguments []interface{}) (interface{}, error) { } } return bestItem, nil - } else if t, ok := start.(string); ok { + case string: bestVal := t bestItem := arr[0] for _, item := range arr[1:] { @@ -669,15 +827,19 @@ func jpfMinBy(arguments []interface{}) (interface{}, error) { } } return bestItem, nil - } else { - return nil, errors.New("invalid type, must be number of string") + default: + return nil, fmt.Errorf("invalid type, must be number or string: %T", t) } } + func jpfType(arguments []interface{}) (interface{}, error) { arg := arguments[0] if _, ok := arg.(float64); ok { return "number", nil } + if _, ok := arg.(json.Number); ok { + return "jsnumber", nil + } if _, ok := arg.(string); ok { return "string", nil } @@ -712,7 +874,14 @@ func jpfValues(arguments []interface{}) (interface{}, error) { return collected, nil } func jpfSort(arguments []interface{}) (interface{}, error) { - if items, ok := toArrayNum(arguments[0]); ok { + for _, f := range []func(interface{}) ([]float64, bool){ + toArrayNum, + toArrayJSONNum, + } { + items, ok := f(arguments[0]) + if !ok { + continue + } d := sort.Float64Slice(items) sort.Stable(d) final := make([]interface{}, len(d)) @@ -745,22 +914,30 @@ func jpfSortBy(arguments []interface{}) (interface{}, error) { if err != nil { return nil, err } - if _, ok := start.(float64); ok { + switch start.(type) { + case json.Number: + sortable := &byExprJSONNumber{intr, node, arr, false} + sort.Stable(sortable) + if sortable.hasError { + return nil, errors.New("error in sort_by comparison") + } + return arr, nil + case float64: sortable := &byExprFloat{intr, node, arr, false} sort.Stable(sortable) if sortable.hasError { return nil, errors.New("error in sort_by comparison") } return arr, nil - } else if _, ok := start.(string); ok { + case string: sortable := &byExprString{intr, node, arr, false} sort.Stable(sortable) if sortable.hasError { return nil, errors.New("error in sort_by comparison") } return arr, nil - } else { - return nil, errors.New("invalid type, must be number of string") + default: + return nil, fmt.Errorf("invalid type, must be number of string: %T", start) } } func jpfJoin(arguments []interface{}) (interface{}, error) { @@ -810,6 +987,9 @@ func jpfToNumber(arguments []interface{}) (interface{}, error) { if v, ok := arg.(float64); ok { return v, nil } + if v, ok := arg.(json.Number); ok { + return v.Float64() + } if v, ok := arg.(string); ok { conv, err := strconv.ParseFloat(v, 64) if err != nil { diff --git a/util.go b/util.go index ddc1b7d..c8a8b66 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,7 @@ package jmespath import ( + "encoding/json" "errors" "reflect" ) @@ -156,6 +157,29 @@ func toArrayNum(data interface{}) ([]float64, bool) { return nil, false } +// ToArrayJSONNum converts an empty interface type to a slice of float64. +// If any element in the array cannot be converted, then nil is returned +// along with a second value of false. +func toArrayJSONNum(data interface{}) ([]float64, bool) { + // Is there a better way to do this with reflect? + if d, ok := data.([]interface{}); ok { + result := make([]float64, len(d)) + for i, el := range d { + item, ok := el.(json.Number) + if !ok { + return nil, false + } + v, err := item.Float64() + if err != nil { + return nil, false + } + result[i] = v + } + return result, true + } + return nil, false +} + // ToArrayStr converts an empty interface type to a slice of strings. // If any element in the array cannot be converted, then nil is returned // along with a second value of false. If the input data could be entirely