Skip to content

Commit 1d35e54

Browse files
committed
Refactor Builder and Expression types for improved clarity and type handling
1 parent 8a29746 commit 1d35e54

File tree

4 files changed

+133
-92
lines changed

4 files changed

+133
-92
lines changed

ast/src/builder.rs

Lines changed: 115 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@ use crate::{
55
types::{
66
ArrayIndexAccessExpression, ArrayLiteral, AssertStatement, AssignExpression, AstNode,
77
BinaryExpression, Block, BoolLiteral, BreakStatement, ConstantDefinition, Definition,
8-
EnumDefinition, Expression, ExpressionStatement, ExternalFunctionDefinition,
9-
FunctionCallExpression, FunctionDefinition, FunctionType, GenericType, Identifier,
10-
IfStatement, Literal, Location, LoopStatement, MemberAccessExpression, NumberLiteral,
11-
OperatorKind, Parameter, ParenthesizedExpression, PrefixUnaryExpression, QualifiedName,
12-
ReturnStatement, SimpleType, SourceFile, SpecDefinition, Statement, StringLiteral,
13-
StructDefinition, StructField, Type, TypeArray, TypeDefinition, TypeDefinitionStatement,
14-
TypeQualifiedName, UnaryOperatorKind, UnitLiteral, UseDirective, UzumakiExpression,
15-
VariableDefinitionStatement,
8+
EnumDefinition, Expression, ExternalFunctionDefinition, FunctionCallExpression,
9+
FunctionDefinition, FunctionType, GenericType, Identifier, IfStatement, Literal, Location,
10+
LoopStatement, MemberAccessExpression, NumberLiteral, OperatorKind, Parameter,
11+
ParenthesizedExpression, PrefixUnaryExpression, QualifiedName, ReturnStatement, SimpleType,
12+
SourceFile, SpecDefinition, Statement, StringLiteral, StructDefinition, StructField, Type,
13+
TypeArray, TypeDefinition, TypeDefinitionStatement, TypeQualifiedName, UnaryOperatorKind,
14+
UnitLiteral, UseDirective, UzumakiExpression, VariableDefinitionStatement,
1615
},
1716
};
1817
use tree_sitter::Node;
@@ -241,7 +240,11 @@ impl<'a> Builder<'a, InitState> {
241240
"type_definition_statement" => {
242241
Definition::Type(self.build_type_definition(parent_id, node, code))
243242
}
244-
_ => panic!("Unexpected definition type: {kind}"),
243+
_ => panic!(
244+
"Unexpected definition type: {}, {}",
245+
node.kind(),
246+
Self::get_location(node, code)
247+
),
245248
}
246249
}
247250

@@ -411,33 +414,58 @@ impl<'a> Builder<'a, InitState> {
411414
fn build_block(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> BlockType {
412415
let id = Self::get_node_id();
413416
let location = Self::get_location(node, code);
414-
//FIXME add to arena
415417
match node.kind() {
416-
"assume_block" => BlockType::Assume(Rc::new(Block::new(
417-
parent_id,
418-
location,
419-
self.build_block_statements(id, &node.child_by_field_name("body").unwrap(), code),
420-
))),
421-
"forall_block" => BlockType::Forall(Rc::new(Block::new(
422-
parent_id,
423-
location,
424-
self.build_block_statements(id, &node.child_by_field_name("body").unwrap(), code),
425-
))),
426-
"exists_block" => BlockType::Exists(Rc::new(Block::new(
427-
parent_id,
428-
location,
429-
self.build_block_statements(id, &node.child_by_field_name("body").unwrap(), code),
430-
))),
431-
"unique_block" => BlockType::Unique(Rc::new(Block::new(
432-
parent_id,
433-
location,
434-
self.build_block_statements(id, &node.child_by_field_name("body").unwrap(), code),
435-
))),
436-
_ => BlockType::Block(Rc::new(Block::new(
437-
parent_id,
438-
location,
439-
self.build_block_statements(id, node, code),
440-
))),
418+
"assume_block" => {
419+
let statements = self.build_block_statements(
420+
id,
421+
&node.child_by_field_name("body").unwrap(),
422+
code,
423+
);
424+
let node = Rc::new(Block::new(parent_id, location, statements));
425+
self.arena.add_node(AstNode::Block(node.clone()), parent_id);
426+
BlockType::Assume(node)
427+
}
428+
"forall_block" => {
429+
let statements = self.build_block_statements(
430+
id,
431+
&node.child_by_field_name("body").unwrap(),
432+
code,
433+
);
434+
let node = Rc::new(Block::new(parent_id, location, statements));
435+
self.arena.add_node(AstNode::Block(node.clone()), parent_id);
436+
BlockType::Forall(node)
437+
}
438+
"exists_block" => {
439+
let statements = self.build_block_statements(
440+
id,
441+
&node.child_by_field_name("body").unwrap(),
442+
code,
443+
);
444+
let node = Rc::new(Block::new(parent_id, location, statements));
445+
self.arena.add_node(AstNode::Block(node.clone()), parent_id);
446+
BlockType::Exists(node)
447+
}
448+
"unique_block" => {
449+
let statements = self.build_block_statements(
450+
id,
451+
&node.child_by_field_name("body").unwrap(),
452+
code,
453+
);
454+
let node = Rc::new(Block::new(parent_id, location, statements));
455+
self.arena.add_node(AstNode::Block(node.clone()), parent_id);
456+
BlockType::Unique(node)
457+
}
458+
"block" => {
459+
let statemetns = self.build_block_statements(id, node, code);
460+
let node = Rc::new(Block::new(parent_id, location, statemetns));
461+
self.arena.add_node(AstNode::Block(node.clone()), parent_id);
462+
BlockType::Block(node)
463+
}
464+
_ => panic!(
465+
"Unexpected block type: {}, {}",
466+
node.kind(),
467+
Self::get_location(node, code)
468+
),
441469
}
442470
}
443471

@@ -461,7 +489,7 @@ impl<'a> Builder<'a, InitState> {
461489
Statement::Block(self.build_block(parent_id, node, code))
462490
}
463491
"expression_statement" => {
464-
Statement::Expression(self.build_expression_statement(parent_id, node, code))
492+
Statement::Expression(self.build_expression(parent_id, node, code, None))
465493
}
466494
"return_statement" => {
467495
Statement::Return(self.build_return_statement(parent_id, node, code))
@@ -491,19 +519,6 @@ impl<'a> Builder<'a, InitState> {
491519
}
492520
}
493521

494-
fn build_expression_statement(
495-
&mut self,
496-
parent_id: u32,
497-
node: &Node,
498-
code: &[u8],
499-
) -> ExpressionStatement {
500-
let id = Self::get_node_id();
501-
let location = Self::get_location(node, code);
502-
let expression = self.build_expression(id, &node.child(0).unwrap(), code);
503-
//TODO what to do with this?
504-
ExpressionStatement::new(id, location, expression)
505-
}
506-
507522
fn build_return_statement(
508523
&mut self,
509524
parent_id: u32,
@@ -512,8 +527,12 @@ impl<'a> Builder<'a, InitState> {
512527
) -> Rc<ReturnStatement> {
513528
let id = Self::get_node_id();
514529
let location = Self::get_location(node, code);
515-
let expression =
516-
self.build_expression(id, &node.child_by_field_name("expression").unwrap(), code);
530+
let expression = self.build_expression(
531+
id,
532+
&node.child_by_field_name("expression").unwrap(),
533+
code,
534+
None,
535+
);
517536

518537
let node = Rc::new(ReturnStatement::new(id, location, expression));
519538
self.arena
@@ -531,7 +550,7 @@ impl<'a> Builder<'a, InitState> {
531550
let location = Self::get_location(node, code);
532551
let condition = node
533552
.child_by_field_name("condition")
534-
.map(|n| self.build_expression(id, &n, code));
553+
.map(|n| self.build_expression(id, &n, code, None));
535554
let body_block = node.child_by_field_name("body").unwrap();
536555
let body = self.build_block(id, &body_block, code);
537556
let node = Rc::new(LoopStatement::new(id, location, condition, body));
@@ -544,7 +563,7 @@ impl<'a> Builder<'a, InitState> {
544563
let id = Self::get_node_id();
545564
let location = Self::get_location(node, code);
546565
let condition_node = node.child_by_field_name("condition").unwrap();
547-
let condition = self.build_expression(id, &condition_node, code);
566+
let condition = self.build_expression(id, &condition_node, code, None);
548567
let if_arm_node = node.child_by_field_name("if_arm").unwrap();
549568
let if_arm = self.build_block(id, &if_arm_node, code);
550569
let else_arm = node
@@ -564,15 +583,15 @@ impl<'a> Builder<'a, InitState> {
564583
) -> Rc<VariableDefinitionStatement> {
565584
let id = Self::get_node_id();
566585
let location = Self::get_location(node, code);
586+
let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code);
567587
let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code);
568-
let type_ = self.build_type(id, &node.child_by_field_name("type").unwrap(), code);
569588
let value = node
570589
.child_by_field_name("value")
571-
.map(|n| self.build_expression(id, &n, code));
590+
.map(|n| self.build_expression(id, &n, code, Some(ty.clone())));
572591
let is_undef = node.child_by_field_name("undef").is_some();
573592

574593
let node = Rc::new(VariableDefinitionStatement::new(
575-
id, location, name, type_, value, is_undef,
594+
id, location, name, ty, value, is_undef,
576595
));
577596
self.arena.add_node(
578597
AstNode::VariableDefinitionStatement(node.clone()),
@@ -598,7 +617,13 @@ impl<'a> Builder<'a, InitState> {
598617
node
599618
}
600619

601-
fn build_expression(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Expression {
620+
fn build_expression(
621+
&mut self,
622+
parent_id: u32,
623+
node: &Node,
624+
code: &[u8],
625+
ty: Option<Type>,
626+
) -> Expression {
602627
let node_kind = node.kind();
603628
match node_kind {
604629
"assign_expression" => {
@@ -626,7 +651,7 @@ impl<'a> Builder<'a, InitState> {
626651
"bool_literal" | "string_literal" | "number_literal" | "array_literal"
627652
| "unit_literal" => Expression::Literal(self.build_literal(parent_id, node, code)),
628653
"uzumaki_keyword" => {
629-
Expression::Uzumaki(self.build_uzumaki_expression(parent_id, node, code))
654+
Expression::Uzumaki(self.build_uzumaki_expression(parent_id, node, code, ty))
630655
}
631656
"identifier" => Expression::Identifier(self.build_identifier(parent_id, node, code)),
632657
_ => panic!("Unexpected expression node kind: {node_kind}"),
@@ -641,8 +666,10 @@ impl<'a> Builder<'a, InitState> {
641666
) -> Rc<AssignExpression> {
642667
let id = Self::get_node_id();
643668
let location = Self::get_location(node, code);
644-
let left = self.build_expression(id, &node.child_by_field_name("left").unwrap(), code);
645-
let right = self.build_expression(id, &node.child_by_field_name("right").unwrap(), code);
669+
let left =
670+
self.build_expression(id, &node.child_by_field_name("left").unwrap(), code, None);
671+
let right =
672+
self.build_expression(id, &node.child_by_field_name("right").unwrap(), code, None);
646673

647674
let node = Rc::new(AssignExpression::new(id, location, left, right));
648675
self.arena
@@ -658,8 +685,8 @@ impl<'a> Builder<'a, InitState> {
658685
) -> Rc<ArrayIndexAccessExpression> {
659686
let id = Self::get_node_id();
660687
let location = Self::get_location(node, code);
661-
let array = self.build_expression(id, &node.named_child(0).unwrap(), code);
662-
let index = self.build_expression(id, &node.named_child(1).unwrap(), code);
688+
let array = self.build_expression(id, &node.named_child(0).unwrap(), code, None);
689+
let index = self.build_expression(id, &node.named_child(1).unwrap(), code, None);
663690

664691
let node = Rc::new(ArrayIndexAccessExpression::new(id, location, array, index));
665692
self.arena
@@ -675,8 +702,12 @@ impl<'a> Builder<'a, InitState> {
675702
) -> Rc<MemberAccessExpression> {
676703
let id = Self::get_node_id();
677704
let location = Self::get_location(node, code);
678-
let expression =
679-
self.build_expression(id, &node.child_by_field_name("expression").unwrap(), code);
705+
let expression = self.build_expression(
706+
id,
707+
&node.child_by_field_name("expression").unwrap(),
708+
code,
709+
None,
710+
);
680711
let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code);
681712

682713
let node = Rc::new(MemberAccessExpression::new(id, location, expression, name));
@@ -693,8 +724,12 @@ impl<'a> Builder<'a, InitState> {
693724
) -> Rc<FunctionCallExpression> {
694725
let id = Self::get_node_id();
695726
let location = Self::get_location(node, code);
696-
let function =
697-
self.build_expression(id, &node.child_by_field_name("function").unwrap(), code);
727+
let function = self.build_expression(
728+
id,
729+
&node.child_by_field_name("function").unwrap(),
730+
code,
731+
None,
732+
);
698733
let mut argument_name_expression_map: Vec<(Option<Rc<Identifier>>, Expression)> =
699734
Vec::new();
700735
let mut pending_name: Option<Rc<Identifier>> = None;
@@ -706,13 +741,13 @@ impl<'a> Builder<'a, InitState> {
706741
match field {
707742
"argument_name" => {
708743
if let Expression::Identifier(id) =
709-
self.build_expression(id, &child, code)
744+
self.build_expression(id, &child, code, None)
710745
{
711746
pending_name = Some(id);
712747
}
713748
}
714749
"argument" => {
715-
let expr = self.build_expression(id, &child, code);
750+
let expr = self.build_expression(id, &child, code, None);
716751
let name = pending_name.take();
717752
argument_name_expression_map.push((name, expr));
718753
}
@@ -747,7 +782,7 @@ impl<'a> Builder<'a, InitState> {
747782
) -> Rc<PrefixUnaryExpression> {
748783
let id = Self::get_node_id();
749784
let location = Self::get_location(node, code);
750-
let expression = self.build_expression(id, &node.child(1).unwrap(), code);
785+
let expression = self.build_expression(id, &node.child(1).unwrap(), code, None);
751786

752787
let operator_node = node.child_by_field_name("operator").unwrap();
753788
let operator = match operator_node.kind() {
@@ -771,7 +806,7 @@ impl<'a> Builder<'a, InitState> {
771806
) -> Rc<AssertStatement> {
772807
let id = Self::get_node_id();
773808
let location = Self::get_location(node, code);
774-
let expression = self.build_expression(id, &node.child(1).unwrap(), code);
809+
let expression = self.build_expression(id, &node.child(1).unwrap(), code, None);
775810
let node = Rc::new(AssertStatement::new(id, location, expression));
776811
self.arena
777812
.add_node(AstNode::AssertStatement(node.clone()), parent_id);
@@ -800,7 +835,7 @@ impl<'a> Builder<'a, InitState> {
800835
) -> Rc<ParenthesizedExpression> {
801836
let id = Self::get_node_id();
802837
let location = Self::get_location(node, code);
803-
let expression = self.build_expression(id, &node.child(1).unwrap(), code);
838+
let expression = self.build_expression(id, &node.child(1).unwrap(), code, None);
804839

805840
let node = Rc::new(ParenthesizedExpression::new(id, location, expression));
806841
self.arena
@@ -816,7 +851,8 @@ impl<'a> Builder<'a, InitState> {
816851
) -> Rc<BinaryExpression> {
817852
let id = Self::get_node_id();
818853
let location = Self::get_location(node, code);
819-
let left = self.build_expression(id, &node.child_by_field_name("left").unwrap(), code);
854+
let left =
855+
self.build_expression(id, &node.child_by_field_name("left").unwrap(), code, None);
820856
let operator_node = node.child_by_field_name("operator").unwrap();
821857
let operator_kind = operator_node.kind();
822858
let operator = match operator_kind {
@@ -841,7 +877,8 @@ impl<'a> Builder<'a, InitState> {
841877
_ => panic!("Unexpected operator node: {operator_kind}"),
842878
};
843879

844-
let right = self.build_expression(id, &node.child_by_field_name("right").unwrap(), code);
880+
let right =
881+
self.build_expression(id, &node.child_by_field_name("right").unwrap(), code, None);
845882

846883
let node = Rc::new(BinaryExpression::new(id, location, left, operator, right));
847884
self.arena
@@ -871,7 +908,7 @@ impl<'a> Builder<'a, InitState> {
871908
let mut elements = Vec::new();
872909
let mut cursor = node.walk();
873910
for child in node.named_children(&mut cursor) {
874-
elements.push(self.build_expression(id, &child, code));
911+
elements.push(self.build_expression(id, &child, code, None));
875912
}
876913

877914
let node = Rc::new(ArrayLiteral::new(id, location, elements));
@@ -978,7 +1015,7 @@ impl<'a> Builder<'a, InitState> {
9781015
let element_type = self.build_type(id, &node.child_by_field_name("type").unwrap(), code);
9791016
let size = node
9801017
.child_by_field_name("length")
981-
.map(|n| Box::new(self.build_expression(id, &n, code)));
1018+
.map(|n| Box::new(self.build_expression(id, &n, code, Some(element_type.clone()))));
9821019

9831020
let node = Rc::new(TypeArray::new(id, location, Box::new(element_type), size));
9841021
self.arena
@@ -1085,10 +1122,11 @@ impl<'a> Builder<'a, InitState> {
10851122
parent_id: u32,
10861123
node: &Node,
10871124
code: &[u8],
1125+
ty: Option<Type>,
10881126
) -> Rc<UzumakiExpression> {
10891127
let id = Self::get_node_id();
10901128
let location = Self::get_location(node, code);
1091-
let node = Rc::new(UzumakiExpression::new(id, location));
1129+
let node = Rc::new(UzumakiExpression::new(id, location, ty.unwrap()));
10921130
self.arena
10931131
.add_node(AstNode::UzumakiExpression(node.clone()), parent_id);
10941132
node
@@ -1104,6 +1142,7 @@ impl<'a> Builder<'a, InitState> {
11041142
node
11051143
}
11061144

1145+
#[allow(clippy::cast_possible_truncation)]
11071146
fn get_node_id() -> u32 {
11081147
uuid::Uuid::new_v4().as_u128() as u32
11091148
}

0 commit comments

Comments
 (0)