Skip to content

Commit 38b46bd

Browse files
authored
Implement break statement (#51)
1 parent cc6cc4d commit 38b46bd

File tree

5 files changed

+100
-31
lines changed

5 files changed

+100
-31
lines changed

src/compiler/mod.rs

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::compiler::value::Value;
77
use crate::compiler::var::Var;
88
use crate::expression;
99
use crate::llvm;
10+
use crate::llvm::BasicBlock;
1011
use crate::llvm::PassManager;
1112
use crate::parser;
1213
use crate::parser::Program;
@@ -29,6 +30,7 @@ pub struct Compiler {
2930
fpm: PassManager,
3031
opt: bool,
3132
stack: Vec<Frame>,
33+
after_loop_blocks: Vec<BasicBlock>,
3234
}
3335

3436
impl Visitor<Value> for Compiler {
@@ -184,11 +186,20 @@ impl Visitor<Value> for Compiler {
184186
Value::Bool(b) => {
185187
self.builder.build_cond_br(&b, &then_block, &else_block);
186188
self.builder.position_builder_at_end(&then_block);
189+
190+
let mut is_break = false;
187191
for stmt in &expr.body {
188192
self.walk(stmt);
193+
194+
if matches!(stmt, expression::Expression::Break) {
195+
is_break = true;
196+
break;
197+
}
189198
}
190-
self.builder.create_br(&after_if_block);
191199

200+
if !is_break {
201+
self.builder.create_br(&after_if_block);
202+
}
192203
self.builder.position_builder_at_end(&else_block);
193204
for stmt in &expr.else_body {
194205
self.walk(stmt);
@@ -517,20 +528,31 @@ impl Visitor<Value> for Compiler {
517528

518529
self.builder.position_builder_at_end(&loop_block);
519530

531+
self.after_loop_blocks.push(after_loop_block);
532+
let mut is_break = false;
533+
520534
for stmt in &expr.body {
521535
self.walk(stmt);
536+
if matches!(stmt, expression::Expression::Break) {
537+
is_break = true;
538+
break;
539+
}
522540
}
523-
let term_pred = self.walk(&expr.predicate);
524541

525-
match term_pred {
526-
Value::Bool(b) => {
527-
self.builder
528-
.build_cond_br(&b, &loop_block, &after_loop_block);
542+
if !is_break {
543+
let term_pred = self.walk(&expr.predicate);
544+
545+
match term_pred {
546+
Value::Bool(b) => {
547+
self.builder
548+
.build_cond_br(&b, &loop_block, &after_loop_block);
549+
}
550+
_ => panic!("type error"),
529551
}
530-
_ => panic!("type error"),
552+
self.builder.position_builder_at_end(&after_loop_block);
531553
}
532554

533-
self.builder.position_builder_at_end(&after_loop_block);
555+
self.after_loop_blocks.pop();
534556
}
535557
_ => panic!("type error"),
536558
}
@@ -571,7 +593,12 @@ impl Visitor<Value> for Compiler {
571593
}
572594

573595
fn visit_break(&mut self) -> Value {
574-
todo!()
596+
let after_loop_block = self.after_loop_blocks.first().unwrap();
597+
598+
self.builder.build_br(after_loop_block);
599+
self.builder.position_builder_at_end(after_loop_block);
600+
601+
Value::Break
575602
}
576603

577604
fn visit_program(&mut self, program: parser::Program) -> Value {
@@ -664,6 +691,7 @@ impl Compiler {
664691
Value::GlobalString(_) => self.context.i8_type().pointer_type(0),
665692
Value::Bool(_) => self.context.i1_type(),
666693
Value::Function { typ, .. } => typ.pointer_type(0),
694+
Value::Break => self.context.void_type(),
667695
};
668696

669697
let existing_ptr: Option<llvm::Value> = match self.get_var_ptr(literal) {
@@ -684,6 +712,7 @@ impl Compiler {
684712
Value::Numeric(_)
685713
| Value::String(_)
686714
| Value::GlobalString(_)
715+
| Value::Break
687716
| Value::Vec(_)
688717
| Value::Bool(_) => self.module.add_global(typ, literal),
689718
Value::Function { val: v, .. } => v,
@@ -696,6 +725,7 @@ impl Compiler {
696725
Value::Null => unreachable!(),
697726
Value::String(_) => todo!(),
698727
Value::Function { .. } => (),
728+
Value::Break => unreachable!(),
699729
Value::Bool(_) => {
700730
ptr.set_initializer(self.context.const_bool(false));
701731
}
@@ -708,6 +738,7 @@ impl Compiler {
708738
Value::GlobalString(_) => Var::GlobalString(ptr),
709739
Value::Vec(_) => Var::Vec(ptr),
710740
Value::Bool(_) => Var::Bool(ptr),
741+
Value::Break => Var::Null,
711742
Value::Function {
712743
typ,
713744
return_type,
@@ -744,6 +775,7 @@ impl Compiler {
744775
Value::GlobalString(_) => Var::GlobalString(ptr),
745776
Value::Vec(_) => Var::Vec(ptr),
746777
Value::Bool(_) => Var::Bool(ptr),
778+
Value::Break => Var::Null,
747779
Value::Function {
748780
typ,
749781
return_type,
@@ -925,6 +957,7 @@ impl Compiler {
925957
builder,
926958
engine,
927959
stack: vec![],
960+
after_loop_blocks: vec![],
928961
fpm,
929962
opt: true,
930963
}

src/compiler/value.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ pub enum Value {
1414
return_type: parser::Type,
1515
},
1616
Vec(llvm::Value),
17+
Break,
1718
}

src/llvm/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ impl Builder {
103103
Value(unsafe { LLVMBuildCondBr(self.0, iff.0, then.0, els.0) })
104104
}
105105

106+
pub fn build_br(&self, dest: &BasicBlock) -> Value {
107+
Value(unsafe { LLVMBuildBr(self.0, dest.0) })
108+
}
109+
106110
pub fn build_alloca(&self, el_type: Type, name: &str) -> Value {
107111
Value(unsafe { LLVMBuildAlloca(self.0, el_type.0, c_str(name).as_ptr()) })
108112
}
@@ -327,6 +331,7 @@ impl PassManager {
327331

328332
impl PassManager {}
329333

334+
#[derive(Clone, Copy, Debug)]
330335
pub struct BasicBlock(*mut llvm::LLVMBasicBlock);
331336

332337
impl BasicBlock {

tests/compiler_tests.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,3 +731,33 @@ fn it_panics_when_wrong_unary_operator() {
731731
let mut compiler = Compiler::new(program);
732732
compiler.compile().unwrap();
733733
}
734+
735+
#[test]
736+
fn it_compile_break_in_while() {
737+
let program = Program {
738+
body: vec![Expression::While(While {
739+
predicate: Box::new(Expression::Bool(true)),
740+
body: vec![Expression::Break],
741+
})],
742+
};
743+
744+
let mut compiler = Compiler::new(program);
745+
compiler.compile().unwrap();
746+
}
747+
748+
#[test]
749+
fn it_compile_break_in_while_and_if() {
750+
let program = Program {
751+
body: vec![Expression::While(While {
752+
predicate: Box::new(Expression::Bool(true)),
753+
body: vec![Expression::Conditional(Conditional {
754+
predicate: Box::new(Expression::Bool(true)),
755+
body: vec![Expression::Break],
756+
else_body: vec![],
757+
})],
758+
})],
759+
};
760+
761+
let mut compiler = Compiler::new(program);
762+
compiler.compile().unwrap();
763+
}

tests/parser_tests.rs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use serde_json::json;
77

88
#[test]
99
fn it_parses_addition() {
10-
let mut parser = Parser::new(&vec![
10+
let mut parser = Parser::new(&[
1111
Token::Numeric(5.2),
1212
Token::Plus,
1313
Token::Numeric(10.0),
@@ -142,7 +142,7 @@ fn it_parses_while_loop() {
142142

143143
#[test]
144144
fn it_returns_error_when_no_curly_after_while_predicate_in_while() {
145-
let mut parser = Parser::new(&vec![
145+
let mut parser = Parser::new(&[
146146
Token::While,
147147
Token::Identifier("x".to_string()),
148148
Token::Less,
@@ -333,7 +333,7 @@ fn it_returns_error_when_no_curly_after_while_predicate_in_if() {
333333

334334
#[test]
335335
fn it_returns_error_when_no_curly_after_while_predicate_in_else() {
336-
let mut parser = Parser::new(&vec![
336+
let mut parser = Parser::new(&[
337337
Token::If,
338338
Token::Identifier("x".to_string()),
339339
Token::Less,
@@ -369,7 +369,7 @@ fn it_displays_correct_syntax_error() {
369369

370370
#[test]
371371
fn it_parses_assignments() {
372-
let mut parser = Parser::new(&vec![
372+
let mut parser = Parser::new(&[
373373
Token::Identifier("x".to_string()),
374374
Token::Equal,
375375
Token::Numeric(10.0),
@@ -400,7 +400,7 @@ fn it_parses_assignments() {
400400

401401
#[test]
402402
fn it_parses_binary_equal() {
403-
let mut parser = Parser::new(&vec![
403+
let mut parser = Parser::new(&[
404404
Token::Identifier("x".to_string()),
405405
Token::DoubleEqual,
406406
Token::Numeric(10.0),
@@ -432,7 +432,7 @@ fn it_parses_binary_equal() {
432432

433433
#[test]
434434
fn it_parses_binary_not_equal() {
435-
let mut parser = Parser::new(&vec![
435+
let mut parser = Parser::new(&[
436436
Token::Identifier("x".to_string()),
437437
Token::NotEqual,
438438
Token::Numeric(10.0),
@@ -464,7 +464,7 @@ fn it_parses_binary_not_equal() {
464464

465465
#[test]
466466
fn it_parses_less_or_equal() {
467-
let mut parser = Parser::new(&vec![
467+
let mut parser = Parser::new(&[
468468
Token::Identifier("x".to_string()),
469469
Token::LessOrEqual,
470470
Token::Numeric(10.0),
@@ -496,7 +496,7 @@ fn it_parses_less_or_equal() {
496496

497497
#[test]
498498
fn it_parses_less() {
499-
let mut parser = Parser::new(&vec![
499+
let mut parser = Parser::new(&[
500500
Token::Identifier("x".to_string()),
501501
Token::Less,
502502
Token::Numeric(10.0),
@@ -528,7 +528,7 @@ fn it_parses_less() {
528528

529529
#[test]
530530
fn it_parses_greater() {
531-
let mut parser = Parser::new(&vec![
531+
let mut parser = Parser::new(&[
532532
Token::Identifier("x".to_string()),
533533
Token::Greater,
534534
Token::Numeric(10.0),
@@ -560,7 +560,7 @@ fn it_parses_greater() {
560560

561561
#[test]
562562
fn it_parses_greater_or_equal() {
563-
let mut parser = Parser::new(&vec![
563+
let mut parser = Parser::new(&[
564564
Token::Identifier("x".to_string()),
565565
Token::GreaterOrEqual,
566566
Token::Numeric(10.0),
@@ -592,7 +592,7 @@ fn it_parses_greater_or_equal() {
592592

593593
#[test]
594594
fn it_parses_subtraction() {
595-
let mut parser = Parser::new(&vec![
595+
let mut parser = Parser::new(&[
596596
Token::Numeric(10.0),
597597
Token::Minus,
598598
Token::Identifier("x".to_string()),
@@ -624,7 +624,7 @@ fn it_parses_subtraction() {
624624

625625
#[test]
626626
fn it_parses_modulo() {
627-
let mut parser = Parser::new(&vec![
627+
let mut parser = Parser::new(&[
628628
Token::Numeric(10.0),
629629
Token::Percent,
630630
Token::Identifier("x".to_string()),
@@ -656,7 +656,7 @@ fn it_parses_modulo() {
656656

657657
#[test]
658658
fn it_parses_multiplication() {
659-
let mut parser = Parser::new(&vec![
659+
let mut parser = Parser::new(&[
660660
Token::Numeric(10.0),
661661
Token::Asterisk,
662662
Token::Identifier("x".to_string()),
@@ -688,7 +688,7 @@ fn it_parses_multiplication() {
688688

689689
#[test]
690690
fn it_parses_division() {
691-
let mut parser = Parser::new(&vec![
691+
let mut parser = Parser::new(&[
692692
Token::Numeric(10.0),
693693
Token::Slash,
694694
Token::Identifier("x".to_string()),
@@ -720,7 +720,7 @@ fn it_parses_division() {
720720

721721
#[test]
722722
fn it_parses_unary_minus() {
723-
let mut parser = Parser::new(&vec![Token::Minus, Token::Numeric(10.0), Token::Eof]);
723+
let mut parser = Parser::new(&[Token::Minus, Token::Numeric(10.0), Token::Eof]);
724724

725725
let ast = parser.parse().unwrap().body;
726726
let json = serde_json::to_value(&ast).unwrap();
@@ -1308,7 +1308,7 @@ fn it_returns_error_when_func_decl_has_no_body() {
13081308

13091309
#[test]
13101310
fn it_parses_func_call() {
1311-
let mut parser = Parser::new(&vec![
1311+
let mut parser = Parser::new(&[
13121312
Token::Identifier("print".to_string()),
13131313
Token::LeftParen,
13141314
Token::RightParen,
@@ -1338,7 +1338,7 @@ fn it_parses_func_call() {
13381338

13391339
#[test]
13401340
fn it_parses_func_call_with_one_arg() {
1341-
let mut parser = Parser::new(&vec![
1341+
let mut parser = Parser::new(&[
13421342
Token::Identifier("print".to_string()),
13431343
Token::LeftParen,
13441344
Token::String("hello".to_string()),
@@ -1439,7 +1439,7 @@ fn it_returns_error_for_call_syntax_on_non_identifiers() {
14391439

14401440
#[test]
14411441
fn it_parses_grouping_expression() {
1442-
let mut parser = Parser::new(&vec![
1442+
let mut parser = Parser::new(&[
14431443
Token::LeftParen,
14441444
Token::String("hello".to_string()),
14451445
Token::RightParen,
@@ -1465,7 +1465,7 @@ fn it_parses_grouping_expression() {
14651465

14661466
#[test]
14671467
fn it_returns_error_for_unterminated_grouping_expresions() {
1468-
let mut parser = Parser::new(&vec![
1468+
let mut parser = Parser::new(&[
14691469
Token::LeftParen,
14701470
Token::String("hello".to_string()),
14711471
Token::Eof,
@@ -1487,7 +1487,7 @@ fn it_returns_error_for_unterminated_grouping_expresions() {
14871487

14881488
#[test]
14891489
fn it_parses_true_bool_literal() {
1490-
let mut parser = Parser::new(&vec![Token::True, Token::Eof]);
1490+
let mut parser = Parser::new(&[Token::True, Token::Eof]);
14911491

14921492
let ast = parser.parse().unwrap().body;
14931493
let json = serde_json::to_value(&ast).unwrap();
@@ -1506,7 +1506,7 @@ fn it_parses_true_bool_literal() {
15061506

15071507
#[test]
15081508
fn it_parses_false_bool_literal() {
1509-
let mut parser = Parser::new(&vec![Token::False, Token::Eof]);
1509+
let mut parser = Parser::new(&[Token::False, Token::Eof]);
15101510

15111511
let ast = parser.parse().unwrap().body;
15121512
let json = serde_json::to_value(&ast).unwrap();
@@ -1525,7 +1525,7 @@ fn it_parses_false_bool_literal() {
15251525

15261526
#[test]
15271527
fn it_parses_break_expression() {
1528-
let mut parser = Parser::new(&vec![Token::Break, Token::Eof]);
1528+
let mut parser = Parser::new(&[Token::Break, Token::Eof]);
15291529

15301530
let ast = parser.parse().unwrap().body;
15311531
let json = serde_json::to_value(&ast).unwrap();

0 commit comments

Comments
 (0)