Skip to content

Commit 36f9adb

Browse files
authored
Refactor groupBy and sortBy (#562)
* Rearrange opcodes * Refactor groupBy * Fix tests * Super opcodes for GroupBy * Refactor sortBy builtin * Fix sortBy bench
1 parent 0bc9d99 commit 36f9adb

File tree

14 files changed

+237
-238
lines changed

14 files changed

+237
-238
lines changed

bench_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ func Benchmark_sortBy(b *testing.B) {
486486
env["arr"].([]Foo)[i] = Foo{Value: v.(int)}
487487
}
488488

489-
program, err := expr.Compile(`sortBy(arr, "Value")`, expr.Env(env))
489+
program, err := expr.Compile(`sortBy(arr, .Value)`, expr.Env(env))
490490
require.NoError(b, err)
491491

492492
var out any

builtin/builtin.go

Lines changed: 49 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ var Builtins = []*Function{
8787
Predicate: true,
8888
Types: types(new(func([]any, func(any) any) map[any][]any)),
8989
},
90+
{
91+
Name: "sortBy",
92+
Predicate: true,
93+
Types: types(new(func([]any, func(any) bool, string) []any)),
94+
},
9095
{
9196
Name: "reduce",
9297
Predicate: true,
@@ -905,109 +910,65 @@ var Builtins = []*Function{
905910
},
906911
{
907912
Name: "sort",
908-
Func: func(args ...any) (any, error) {
913+
Safe: func(args ...any) (any, uint, error) {
909914
if len(args) != 1 && len(args) != 2 {
910-
return nil, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args))
915+
return nil, 0, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args))
911916
}
912917

913-
v := reflect.ValueOf(args[0])
914-
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
915-
return nil, fmt.Errorf("cannot sort %s", v.Kind())
916-
}
918+
var array []any
917919

918-
orderBy := OrderBy{}
919-
if len(args) == 2 {
920-
dir, err := ascOrDesc(args[1])
921-
if err != nil {
922-
return nil, err
920+
switch in := args[0].(type) {
921+
case []any:
922+
array = make([]any, len(in))
923+
copy(array, in)
924+
case []int:
925+
array = make([]any, len(in))
926+
for i, v := range in {
927+
array[i] = v
928+
}
929+
case []float64:
930+
array = make([]any, len(in))
931+
for i, v := range in {
932+
array[i] = v
933+
}
934+
case []string:
935+
array = make([]any, len(in))
936+
for i, v := range in {
937+
array[i] = v
923938
}
924-
orderBy.Desc = dir
925939
}
926940

927-
sortable, err := copyArray(v, orderBy)
928-
if err != nil {
929-
return nil, err
930-
}
931-
sort.Sort(sortable)
932-
return sortable.Array, nil
933-
},
934-
Validate: func(args []reflect.Type) (reflect.Type, error) {
935-
if len(args) != 1 && len(args) != 2 {
936-
return anyType, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args))
937-
}
938-
switch kind(args[0]) {
939-
case reflect.Interface, reflect.Slice, reflect.Array:
940-
default:
941-
return anyType, fmt.Errorf("cannot sort %s", args[0])
942-
}
941+
var desc bool
943942
if len(args) == 2 {
944-
switch kind(args[1]) {
945-
case reflect.String, reflect.Interface:
943+
switch args[1].(string) {
944+
case "asc":
945+
desc = false
946+
case "desc":
947+
desc = true
946948
default:
947-
return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1])
949+
return nil, 0, fmt.Errorf("invalid order %s, expected asc or desc", args[1])
948950
}
949951
}
950-
return arrayType, nil
951-
},
952-
},
953-
{
954-
Name: "sortBy",
955-
Func: func(args ...any) (any, error) {
956-
if len(args) != 2 && len(args) != 3 {
957-
return nil, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args))
958-
}
959952

960-
v := reflect.ValueOf(args[0])
961-
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
962-
return nil, fmt.Errorf("cannot sort %s", v.Kind())
963-
}
964-
965-
orderBy := OrderBy{}
966-
967-
field, ok := args[1].(string)
968-
if !ok {
969-
return nil, fmt.Errorf("invalid argument for sort (expected string, got %s)", reflect.TypeOf(args[1]))
970-
}
971-
orderBy.Field = field
972-
973-
if len(args) == 3 {
974-
dir, err := ascOrDesc(args[2])
975-
if err != nil {
976-
return nil, err
977-
}
978-
orderBy.Desc = dir
979-
}
980-
981-
sortable, err := copyArray(v, orderBy)
982-
if err != nil {
983-
return nil, err
953+
sortable := &runtime.Sort{
954+
Desc: desc,
955+
Array: array,
984956
}
985957
sort.Sort(sortable)
986-
return sortable.Array, nil
987-
},
988-
Validate: func(args []reflect.Type) (reflect.Type, error) {
989-
if len(args) != 2 && len(args) != 3 {
990-
return anyType, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args))
991-
}
992-
switch kind(args[0]) {
993-
case reflect.Interface, reflect.Slice, reflect.Array:
994-
default:
995-
return anyType, fmt.Errorf("cannot sort %s", args[0])
996-
}
997-
switch kind(args[1]) {
998-
case reflect.String, reflect.Interface:
999-
default:
1000-
return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1])
1001-
}
1002-
if len(args) == 3 {
1003-
switch kind(args[2]) {
1004-
case reflect.String, reflect.Interface:
1005-
default:
1006-
return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1])
1007-
}
1008-
}
1009-
return arrayType, nil
958+
959+
return sortable.Array, uint(len(array)), nil
1010960
},
961+
Types: types(
962+
new(func([]any, string) []any),
963+
new(func([]int, string) []any),
964+
new(func([]float64, string) []any),
965+
new(func([]string, string) []any),
966+
967+
new(func([]any) []any),
968+
new(func([]float64) []any),
969+
new(func([]string) []any),
970+
new(func([]int) []any),
971+
),
1011972
},
1012973
bitFunc("bitand", func(x, y int) (any, error) {
1013974
return x & y, nil

builtin/builtin_test.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ func TestBuiltin_sort(t *testing.T) {
530530
{`sort(ArrayOfInt)`, []any{1, 2, 3}},
531531
{`sort(ArrayOfFloat)`, []any{1.0, 2.0, 3.0}},
532532
{`sort(ArrayOfInt, 'desc')`, []any{3, 2, 1}},
533-
{`sortBy(ArrayOfFoo, 'Value')`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}},
534-
{`sortBy([{id: "a"}, {id: "b"}], "id", "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}},
533+
{`sortBy(ArrayOfFoo, .Value)`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}},
534+
{`sortBy([{id: "a"}, {id: "b"}], .id, "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}},
535535
}
536536

537537
for _, test := range tests {
@@ -546,6 +546,20 @@ func TestBuiltin_sort(t *testing.T) {
546546
}
547547
}
548548

549+
func TestBuiltin_sort_i64(t *testing.T) {
550+
env := map[string]any{
551+
"array": []int{1, 2, 3},
552+
"i64": int64(1),
553+
}
554+
555+
program, err := expr.Compile(`sort(map(array, i64))`, expr.Env(env))
556+
require.NoError(t, err)
557+
558+
out, err := expr.Run(program, env)
559+
require.NoError(t, err)
560+
assert.Equal(t, []any{int64(1), int64(1), int64(1)}, out)
561+
}
562+
549563
func TestBuiltin_bitOpsFunc(t *testing.T) {
550564
tests := []struct {
551565
input string

builtin/sort.go

Lines changed: 0 additions & 96 deletions
This file was deleted.

checker/checker.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
633633
if isAny(collection) {
634634
return arrayType, info{}
635635
}
636-
return reflect.SliceOf(collection.Elem()), info{}
636+
return arrayType, info{}
637637
}
638638
return v.error(node.Arguments[1], "predicate should has one input and one output param")
639639

@@ -651,7 +651,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
651651
closure.NumOut() == 1 &&
652652
closure.NumIn() == 1 && isAny(closure.In(0)) {
653653

654-
return reflect.SliceOf(closure.Out(0)), info{}
654+
return arrayType, info{}
655655
}
656656
return v.error(node.Arguments[1], "predicate should has one input and one output param")
657657

@@ -739,6 +739,28 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
739739
}
740740
return v.error(node.Arguments[1], "predicate should has one input and one output param")
741741

742+
case "sortBy":
743+
collection, _ := v.visit(node.Arguments[0])
744+
if !isArray(collection) && !isAny(collection) {
745+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
746+
}
747+
748+
v.begin(collection)
749+
closure, _ := v.visit(node.Arguments[1])
750+
v.end()
751+
752+
if len(node.Arguments) == 3 {
753+
_, _ = v.visit(node.Arguments[2])
754+
}
755+
756+
if isFunc(closure) &&
757+
closure.NumOut() == 1 &&
758+
closure.NumIn() == 1 && isAny(closure.In(0)) {
759+
760+
return reflect.TypeOf([]any{}), info{}
761+
}
762+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
763+
742764
case "reduce":
743765
collection, _ := v.visit(node.Arguments[0])
744766
if !isArray(collection) && !isAny(collection) {

checker/checker_test.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,6 @@ invalid operation: < (mismatched types mock.Bar and int) (1:29)
400400
| all(ArrayOfFoo, {#.Method() < 0})
401401
| ............................^
402402
403-
map(Any, {0})[0] + "str"
404-
invalid operation: + (mismatched types int and string) (1:18)
405-
| map(Any, {0})[0] + "str"
406-
| .................^
407-
408403
Variadic()
409404
not enough arguments to call Variadic (1:1)
410405
| Variadic()
@@ -445,11 +440,6 @@ builtin map takes only array (got int) (1:5)
445440
| map(1, {2})
446441
| ....^
447442
448-
map(filter(ArrayOfFoo, {true}), {.Not})
449-
type mock.Foo has no field Not (1:35)
450-
| map(filter(ArrayOfFoo, {true}), {.Not})
451-
| ..................................^
452-
453443
ArrayOfFoo[Foo]
454444
array elements can only be selected using an integer (got mock.Foo) (1:12)
455445
| ArrayOfFoo[Foo]

0 commit comments

Comments
 (0)