Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ test_control_flow_w_jlox: build
]" \
$(shell pwd)/dist/main.out


test_functions_w_jlox: build
CODECRAFTERS_REPOSITORY_DIR=./craftinginterpreters/build/gen/chap10_functions \
CODECRAFTERS_TEST_CASES_JSON="[ \
Expand All @@ -154,7 +153,7 @@ test_functions_w_jlox: build
$(shell pwd)/dist/main.out

test_resolving_w_jlox: build
CODECRAFTERS_REPOSITORY_DIR=./craftinginterpreters/build/gen/chap10_functions \
CODECRAFTERS_REPOSITORY_DIR=./craftinginterpreters/build/gen/chap11_resolving \
CODECRAFTERS_TEST_CASES_JSON="[ \
{\"slug\":\"r1\",\"tester_log_prefix\":\"stage_701\",\"title\":\"Stage #701: Resolving: Function Resolution\"}, \
{\"slug\":\"r2\",\"tester_log_prefix\":\"stage_702\",\"title\":\"Stage #702: Resolving: Variable Resolution\"}, \
Expand All @@ -166,4 +165,4 @@ test_resolving_w_jlox: build
]" \
$(shell pwd)/dist/main.out

test_all: test_scanning_w_jlox test_parsing_w_jlox test_evaluation_w_jlox test_statements_w_jlox test_control_flow_w_jlox
test_all: test_scanning_w_jlox test_parsing_w_jlox test_evaluation_w_jlox test_statements_w_jlox test_control_flow_w_jlox test_functions_w_jlox test_resolving_w_jlox
10 changes: 8 additions & 2 deletions internal/lox/api/run_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@ func Run(source string) (string, int, string) {
tokens := scanner.ScanTokens(mockStdout, mockStderr)
parser := lox.NewParser(tokens)
statements := parser.Parse(mockStdout, mockStderr)
lox.Interpret(statements, mockStdout, mockStderr)
locals, err := lox.Resolve(statements)
if err != nil && lox.HadSemanticError {
return "", 65, err.Error()
}

env := lox.NewGlobal()
lox.Interpret(statements, env, locals, mockStdout, mockStderr)

exitCode := 0
if lox.HadParseError {
if lox.HadParseError || lox.HadSemanticError {
exitCode = 65
} else if lox.HadRuntimeError {
exitCode = 70
Expand Down
19 changes: 19 additions & 0 deletions internal/lox/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,20 @@ func (e *Environment) Get(name Token) (interface{}, error) {
return nil, MakeRuntimeError(name, fmt.Sprintf("Undefined variable '%s'.", name.Lexeme))
}

// GetAt lookups a variable a certain distance up the chain of environments
func (e *Environment) GetAt(distance int, name Token) (interface{}, error) {
return e.Ancestor(distance).Get(name)
}

// Ancestor reaches an environment up the environment chain
func (e *Environment) Ancestor(distance int) *Environment {
env := e
for range distance {
env = env.enclosing
}
return env
}

// Assign sets a new value to an old variable
func (e *Environment) Assign(name Token, value interface{}) error {
if _, prs := e.values[name.Lexeme]; prs {
Expand All @@ -57,3 +71,8 @@ func (e *Environment) Assign(name Token, value interface{}) error {
}
return MakeRuntimeError(name, fmt.Sprintf("Undefined variable '%s'.", name.Lexeme))
}

// AssignAt sets a new value to an old variable
func (e *Environment) AssignAt(distance int, name Token, value any) error {
return e.Ancestor(distance).Assign(name, value)
}
11 changes: 8 additions & 3 deletions internal/lox/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ type UserFunction struct {
Callable
Declaration *Function
Closure *Environment
Locals Locals // TODO: Pass pointer to Locals
}

// NewUserFunction creates a new UserFunction
func NewUserFunction(declaration *Function, closure *Environment) *UserFunction {
return &UserFunction{Declaration: declaration, Closure: closure}
func NewUserFunction(declaration *Function, closure *Environment, locals Locals) *UserFunction {
return &UserFunction{
Declaration: declaration,
Closure: closure,
Locals: locals,
}
}

// Call executes a user-defined Lox function
Expand All @@ -54,7 +59,7 @@ func (u *UserFunction) Call(arguments []interface{}, globalEnv *Environment, std
}

for _, stmt := range u.Declaration.Body {
_, err := Eval(stmt, env, stdout, stderr)
_, err := Eval(stmt, env, u.Locals, stdout, stderr)

if err != nil {
if r, ok := err.(ReturnError); ok {
Expand Down
13 changes: 11 additions & 2 deletions internal/lox/globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ import (
"time"
)

func InitializeNativeFunctions(env *Environment) {
env.Define("clock", &NativeFunction{
// GlobalEnv is the global environment
var GlobalEnv = NewGlobal()
var globals = GlobalEnv

func InitializeNativeFunctions() {
GlobalEnv.Define("clock", &NativeFunction{
arity: 0,
nativeCall: func(args []interface{}) (interface{}, error) {
exponentNotation := float64(time.Now().Unix())
Expand All @@ -14,3 +18,8 @@ func InitializeNativeFunctions(env *Environment) {
},
})
}

// ResetGlobalEnv resets the GlobalEnv to its original reference
func ResetGlobalEnv() {
GlobalEnv = globals
}
68 changes: 40 additions & 28 deletions internal/lox/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ReturnError struct {
}

func BasicInterpret(expression Expr, stdout io.Writer, stderr io.Writer) {
result, err := Eval(expression, NewGlobal(), stdout, stderr)
result, err := Eval(expression, NewGlobal(), Locals{}, stdout, stderr)
if err != nil {
LogRuntimeError(err, stderr)
return
Expand All @@ -29,27 +29,30 @@ func BasicInterpret(expression Expr, stdout io.Writer, stderr io.Writer) {
fmt.Fprintln(stdout, result)
}

func Interpret(statements []Stmt, stdout io.Writer, stderr io.Writer) {
env := NewGlobal()
InitializeNativeFunctions(env)
func Interpret(statements []Stmt, env *Environment, locals Locals, stdout io.Writer, stderr io.Writer) {
OldGlobalEnv := GlobalEnv
GlobalEnv = env
InitializeNativeFunctions()

for _, stmt := range statements {
_, err := Eval(stmt, env, stdout, stderr)
_, err := Eval(stmt, env, locals, stdout, stderr)
if err != nil {
LogRuntimeError(err, stderr)
return
}
}
GlobalEnv = OldGlobalEnv
}

// Eval evaluates the given AST
func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Writer) (interface{}, error) {
func Eval(node Node, environment *Environment, locals Locals, stdout io.Writer, stderr io.Writer) (interface{}, error) {
switch n := node.(type) {
case *Literal:
return n.Value, nil
case *Grouping:
return Eval(n.Expression, environment, stdout, stderr)
return Eval(n.Expression, environment, locals, stdout, stderr)
case *Unary:
right, err := Eval(n.Right, environment, stdout, stderr)
right, err := Eval(n.Right, environment, locals, stdout, stderr)
if err != nil {
return right, err
} else if n.Operator.Type == MINUS {
Expand All @@ -62,11 +65,11 @@ func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Write
return !isTruthy(right), nil
}
case *Binary:
left, err := Eval(n.Left, environment, stdout, stderr)
left, err := Eval(n.Left, environment, locals, stdout, stderr)
if err != nil {
return left, err
}
right, err := Eval(n.Right, environment, stdout, stderr)
right, err := Eval(n.Right, environment, locals, stdout, stderr)
if err != nil {
return right, err
}
Expand Down Expand Up @@ -161,7 +164,7 @@ func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Write
return isEqual(left, right), nil
}
case *Print:
value, err := Eval(n.Expression, environment, stdout, stderr)
value, err := Eval(n.Expression, environment, locals, stdout, stderr)
if err != nil {
return value, err
}
Expand All @@ -176,14 +179,14 @@ func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Write
}
return nil, nil
case *Expression:
r, err := Eval(n.Expression, environment, stdout, stderr)
r, err := Eval(n.Expression, environment, locals, stdout, stderr)
if err != nil {
return r, err
}
return nil, nil
case *Var:
if n.Initializer != nil {
value, err := Eval(n.Initializer, environment, stdout, stderr)
value, err := Eval(n.Initializer, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
Expand All @@ -194,38 +197,46 @@ func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Write
}
return nil, nil
case *Variable:
return environment.Get(n.Name)
if distance, ok := locals[n]; ok {
return environment.GetAt(distance, n.Name)
}
return GlobalEnv.Get(n.Name)
case *Assign:
value, err := Eval(n.Value, environment, stdout, stderr)
value, err := Eval(n.Value, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
if err = environment.Assign(n.Name, value); err == nil {
if distance, ok := locals[n]; ok {
if err := environment.AssignAt(distance, n.Name, value); err == nil {
return value, nil
}
return nil, err
} else if err := GlobalEnv.Assign(n.Name, value); err == nil {
return value, nil
}
return nil, err
case *Block:
newEnvironment := New(environment)
for _, stmt := range n.Statements {
_, err := Eval(stmt, newEnvironment, stdout, stderr)
_, err := Eval(stmt, newEnvironment, locals, stdout, stderr)
if err != nil {
return nil, err
}
}
return nil, nil
case *If:
condition, err := Eval(n.Condition, environment, stdout, stderr)
condition, err := Eval(n.Condition, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
if isTruthy(condition) {
return Eval(n.ThenBranch, environment, stdout, stderr)
return Eval(n.ThenBranch, environment, locals, stdout, stderr)
} else if n.ElseBranch != nil {
return Eval(n.ElseBranch, environment, stdout, stderr)
return Eval(n.ElseBranch, environment, locals, stdout, stderr)
}
return nil, nil
case *Logical:
left, err := Eval(n.Left, environment, stdout, stderr)
left, err := Eval(n.Left, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
Expand All @@ -239,16 +250,16 @@ func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Write
return left, nil
}
}
return Eval(n.Right, environment, stdout, stderr)
return Eval(n.Right, environment, locals, stdout, stderr)
case *Call:
callee, err := Eval(n.Callee, environment, stdout, stderr)
callee, err := Eval(n.Callee, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}

args := make([]interface{}, 0)
for _, arg := range n.Arguments {
a, err := Eval(arg, environment, stdout, stderr)
a, err := Eval(arg, environment, locals, stdout, stderr)
if err == nil {
args = append(args, a)
} else {
Expand All @@ -268,28 +279,28 @@ func Eval(node Node, environment *Environment, stdout io.Writer, stderr io.Write
return function.Call(args, environment, stdout, stderr)
case *While:
for {
condition, err := Eval(n.Condition, environment, stdout, stderr)
condition, err := Eval(n.Condition, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
if !isTruthy(condition) {
break
}
_, err = Eval(n.Statement, environment, stdout, stderr)
_, err = Eval(n.Statement, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
}
return nil, nil
case *Function:
function := NewUserFunction(n, environment)
function := NewUserFunction(n, environment, locals)
environment.Define(n.Name.Lexeme, function)
return nil, nil
case *Return:
var value interface{}
var err error
if n.Value != nil {
value, err = Eval(n.Value, environment, stdout, stderr)
value, err = Eval(n.Value, environment, locals, stdout, stderr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -331,4 +342,5 @@ func checkNumberOperand(operator Token, value interface{}, msg string) error {
func ClearErrorFlags() {
HadParseError = false
HadRuntimeError = false
HadSemanticError = false
}
Loading
Loading