Skip to content

Commit 84ac0b8

Browse files
committed
Add OpCallSafe
1 parent e53cefe commit 84ac0b8

File tree

7 files changed

+41
-32
lines changed

7 files changed

+41
-32
lines changed

builtin/builtin.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,16 @@ var Builtins = []*Function{
316316
},
317317
{
318318
Name: "repeat",
319-
ValidateArgs: func(args ...any) (any, error) {
319+
Safe: func(args ...any) (any, uint, error) {
320+
s := args[0].(string)
320321
n := runtime.ToInt(args[1])
321322
if n < 0 {
322-
panic(fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n))
323+
return nil, 0, fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n)
323324
}
324-
return uint(n), nil
325-
},
326-
Func: func(args ...any) (any, error) {
327-
return strings.Repeat(args[0].(string), runtime.ToInt(args[1])), nil
325+
if n > 1e6 {
326+
return nil, 0, fmt.Errorf("memory budget exceeded")
327+
}
328+
return strings.Repeat(s, n), uint(len(s) * n), nil
328329
},
329330
Types: types(strings.Repeat),
330331
},

builtin/function.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import (
55
)
66

77
type Function struct {
8-
Name string
9-
Func func(args ...any) (any, error)
10-
Fast func(arg any) any
11-
ValidateArgs func(args ...any) (any, error)
12-
Types []reflect.Type
13-
Validate func(args []reflect.Type) (reflect.Type, error)
14-
Predicate bool
8+
Name string
9+
Fast func(arg any) any
10+
Func func(args ...any) (any, error)
11+
Safe func(args ...any) (any, uint, error)
12+
Types []reflect.Type
13+
Validate func(args []reflect.Type) (reflect.Type, error)
14+
Predicate bool
1515
}
1616

1717
func (f *Function) Type() reflect.Type {

compiler/compiler.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func (c *compiler) addConstant(constant any) int {
116116
indexable := true
117117
hash := constant
118118
switch reflect.TypeOf(constant).Kind() {
119-
case reflect.Slice, reflect.Map, reflect.Struct:
119+
case reflect.Slice, reflect.Map, reflect.Struct, reflect.Func:
120120
indexable = false
121121
}
122122
if field, ok := constant.(*runtime.Field); ok {
@@ -908,13 +908,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
908908
c.compile(arg)
909909
}
910910

911-
if f.ValidateArgs != nil {
912-
c.emit(OpLoadFunc, c.addFunction("$_validate_args_"+f.Name, f.ValidateArgs))
913-
c.emit(OpValidateArgs, len(node.Arguments))
914-
}
915-
916911
if f.Fast != nil {
917912
c.emit(OpCallBuiltin1, id)
913+
} else if f.Safe != nil {
914+
c.emit(OpPush, c.addConstant(f.Safe))
915+
c.emit(OpCallSafe, len(node.Arguments))
918916
} else if f.Func != nil {
919917
c.emitFunction(f, len(node.Arguments))
920918
}

vm/opcodes.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ const (
5858
OpCall3
5959
OpCallN
6060
OpCallFast
61+
OpCallSafe
6162
OpCallTyped
6263
OpCallBuiltin1
63-
OpValidateArgs
6464
OpArray
6565
OpMap
6666
OpLen

vm/program.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,16 @@ func (program *Program) DisassembleWriter(w io.Writer) {
290290
case OpCallFast:
291291
argument("OpCallFast")
292292

293+
case OpCallSafe:
294+
argument("OpCallSafe")
295+
293296
case OpCallTyped:
294297
signature := reflect.TypeOf(FuncTypes[arg]).Elem().String()
295298
_, _ = fmt.Fprintf(w, "%v\t%v\t<%v>\t%v\n", pp, "OpCallTyped", arg, signature)
296299

297300
case OpCallBuiltin1:
298301
builtinArg("OpCallBuiltin1")
299302

300-
case OpValidateArgs:
301-
argument("OpValidateArgs")
302-
303303
case OpArray:
304304
code("OpArray")
305305

vm/utils.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@ import (
44
"reflect"
55
)
66

7-
type Function = func(params ...any) (any, error)
7+
type (
8+
Function = func(params ...any) (any, error)
9+
SafeFunction = func(params ...any) (any, uint, error)
10+
)
811

12+
// MemoryBudget represents an upper limit of memory usage.
913
var MemoryBudget uint = 1e6
1014

1115
var errorType = reflect.TypeOf((*error)(nil)).Elem()

vm/vm.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -389,20 +389,26 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
389389
}
390390
vm.push(fn(in...))
391391

392+
case OpCallSafe:
393+
fn := vm.pop().(SafeFunction)
394+
size := arg
395+
in := make([]any, size)
396+
for i := int(size) - 1; i >= 0; i-- {
397+
in[i] = vm.pop()
398+
}
399+
out, mem, err := fn(in...)
400+
if err != nil {
401+
panic(err)
402+
}
403+
vm.memGrow(mem)
404+
vm.push(out)
405+
392406
case OpCallTyped:
393407
vm.push(vm.call(vm.pop(), arg))
394408

395409
case OpCallBuiltin1:
396410
vm.push(builtin.Builtins[arg].Fast(vm.pop()))
397411

398-
case OpValidateArgs:
399-
fn := vm.pop().(Function)
400-
mem, err := fn(vm.Stack[len(vm.Stack)-arg:]...)
401-
if err != nil {
402-
panic(err)
403-
}
404-
vm.memGrow(mem.(uint))
405-
406412
case OpArray:
407413
size := vm.pop().(int)
408414
vm.memGrow(uint(size))

0 commit comments

Comments
 (0)