Skip to content

Commit 1829e72

Browse files
committed
Add countBy()
1 parent 7221fe0 commit 1829e72

File tree

9 files changed

+74
-0
lines changed

9 files changed

+74
-0
lines changed

bench_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,17 @@ func Benchmark_groupBy(b *testing.B) {
512512

513513
require.Equal(b, 6, out.([]any)[0])
514514
}
515+
516+
func Benchmark_countBy(b *testing.B) {
517+
program, err := expr.Compile(`countBy(1..100, # % 7)[6]`)
518+
require.NoError(b, err)
519+
520+
var out any
521+
b.ResetTimer()
522+
for n := 0; n < b.N; n++ {
523+
out, _ = vm.Run(program, nil)
524+
}
525+
b.StopTimer()
526+
527+
require.Equal(b, 14, out.(int))
528+
}

builtin/builtin.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,4 +695,9 @@ var Builtins = []*ast.Function{
695695
Predicate: true,
696696
Types: types(new(func([]any, func(any) any) map[any][]any)),
697697
},
698+
{
699+
Name: "countBy",
700+
Predicate: true,
701+
Types: types(new(func([]any, func(any) any) map[any]int)),
702+
},
698703
}

checker/checker.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,25 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
736736
return reflect.TypeOf(map[any][]any{}), info{}
737737
}
738738
return v.error(node.Arguments[1], "predicate should has one input and one output param")
739+
740+
case "countBy":
741+
collection, _ := v.visit(node.Arguments[0])
742+
if !isArray(collection) && !isAny(collection) {
743+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
744+
}
745+
746+
v.begin(collection)
747+
closure, _ := v.visit(node.Arguments[1])
748+
v.end()
749+
750+
if isFunc(closure) &&
751+
closure.NumOut() == 1 &&
752+
closure.NumIn() == 1 && isAny(closure.In(0)) {
753+
754+
return reflect.TypeOf(map[any]int{}), info{}
755+
}
756+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
757+
739758
}
740759

741760
if id, ok := builtin.Index[node.Name]; ok {

compiler/compiler.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,18 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
824824
c.emit(OpGetGroupBy)
825825
c.emit(OpEnd)
826826
return
827+
828+
case "countBy":
829+
c.compile(node.Arguments[0])
830+
c.emit(OpBegin)
831+
c.emitLoop(func() {
832+
c.compile(node.Arguments[1])
833+
c.emit(OpCountBy)
834+
})
835+
c.emit(OpGetCountBy)
836+
c.emit(OpEnd)
837+
return
838+
827839
}
828840

829841
if id, ok := builtin.Index[node.Name]; ok {

expr_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,10 @@ func TestExpr(t *testing.T) {
10631063
`groupBy(ArrayOfFoo, .Value).foo`,
10641064
[]any{env.ArrayOfFoo[0]},
10651065
},
1066+
{
1067+
`countBy(1..9, # % 2)`,
1068+
map[any]int{0: 4, 1: 5},
1069+
},
10661070
}
10671071

10681072
for _, tt := range tests {

parser/parser.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var predicates = map[string]struct {
2929
"findLast": {2},
3030
"findLastIndex": {2},
3131
"groupBy": {2},
32+
"countBy": {2},
3233
}
3334

3435
type parser struct {

vm/opcodes.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@ const (
7373
OpGetCount
7474
OpGetLen
7575
OpGetGroupBy
76+
OpGetCountBy
7677
OpPointer
7778
OpThrow
7879
OpGroupBy
80+
OpCountBy
7981
OpBegin
8082
OpEnd // This opcode must be at the end of this list.
8183
)

vm/program.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ func (program *Program) Opcodes(w io.Writer) {
292292
case OpGetGroupBy:
293293
code("OpGetGroupBy")
294294

295+
case OpGetCountBy:
296+
code("OpGetCountBy")
297+
295298
case OpPointer:
296299
code("OpPointer")
297300

@@ -304,6 +307,9 @@ func (program *Program) Opcodes(w io.Writer) {
304307
case OpGroupBy:
305308
code("OpGroupBy")
306309

310+
case OpCountBy:
311+
code("OpCountBy")
312+
307313
case OpEnd:
308314
code("OpEnd")
309315

vm/vm.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type Scope struct {
4444
Len int
4545
Count int
4646
GroupBy map[any][]any
47+
CountBy map[any]int
4748
}
4849

4950
func Debug() *VM {
@@ -465,6 +466,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
465466
case OpGetGroupBy:
466467
vm.push(vm.Scope().GroupBy)
467468

469+
case OpGetCountBy:
470+
vm.push(vm.Scope().CountBy)
471+
468472
case OpPointer:
469473
scope := vm.Scope()
470474
vm.push(scope.Array.Index(scope.Index).Interface())
@@ -481,6 +485,13 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
481485
key := vm.pop()
482486
scope.GroupBy[key] = append(scope.GroupBy[key], it)
483487

488+
case OpCountBy:
489+
scope := vm.Scope()
490+
if scope.CountBy == nil {
491+
scope.CountBy = make(map[any]int)
492+
}
493+
scope.CountBy[vm.pop()]++
494+
484495
case OpBegin:
485496
a := vm.pop()
486497
array := reflect.ValueOf(a)

0 commit comments

Comments
 (0)