Skip to content

Commit 68760e1

Browse files
committed
Refactor AST and Type Inference: Rename AssignExpression to AssignStatement; update related handling in Builder and TypeChecker; adjust TypeInfo and WatEmitter for consistency
1 parent 6009c2c commit 68760e1

File tree

7 files changed

+209
-82
lines changed

7 files changed

+209
-82
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ inf-wasmparser = "0.0.8"
3333
wat-fmt = "0.0.8"
3434
wasm-fmt = { path = "./wasm-fmt", version = "0.0.1" }
3535
tree-sitter = "0.25.3"
36-
tree-sitter-inference = "0.0.30"
36+
tree-sitter-inference = "0.0.31"
3737
anyhow = "1.0.98"

ast/src/builder.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
symbols::SymbolType,
88
t_ast::TypedAst,
99
types::{
10-
ArrayIndexAccessExpression, ArrayLiteral, AssertStatement, AssignExpression, AstNode,
10+
ArrayIndexAccessExpression, ArrayLiteral, AssertStatement, AssignStatement, AstNode,
1111
BinaryExpression, Block, BlockType, BoolLiteral, BreakStatement, ConstantDefinition,
1212
Definition, EnumDefinition, Expression, ExternalFunctionDefinition, FunctionCallExpression,
1313
FunctionDefinition, FunctionType, GenericType, Identifier, IfStatement, Literal, Location,
@@ -540,6 +540,9 @@ impl<'a> Builder<'a, InitState> {
540540

541541
fn build_statement(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Statement {
542542
match node.kind() {
543+
"assign_statement" => {
544+
Statement::Assign(self.build_assign_statement(parent_id, node, code))
545+
}
543546
"block" | "forall_block" | "assume_block" | "exists_block" | "unique_block" => {
544547
Statement::Block(self.build_block(parent_id, node, code))
545548
}
@@ -675,9 +678,6 @@ impl<'a> Builder<'a, InitState> {
675678
fn build_expression(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Expression {
676679
let node_kind = node.kind();
677680
match node_kind {
678-
"assign_expression" => {
679-
Expression::Assign(self.build_assign_expression(parent_id, node, code), None)
680-
}
681681
"array_index_access_expression" => Expression::ArrayIndexAccess(
682682
self.build_array_index_access_expression(parent_id, node, code),
683683
None,
@@ -716,20 +716,20 @@ impl<'a> Builder<'a, InitState> {
716716
}
717717
}
718718

719-
fn build_assign_expression(
719+
fn build_assign_statement(
720720
&mut self,
721721
parent_id: u32,
722722
node: &Node,
723723
code: &[u8],
724-
) -> Rc<AssignExpression> {
724+
) -> Rc<AssignStatement> {
725725
let id = Self::get_node_id();
726726
let location = Self::get_location(node, code);
727727
let left = self.build_expression(id, &node.child_by_field_name("left").unwrap(), code);
728728
let right = self.build_expression(id, &node.child_by_field_name("right").unwrap(), code);
729729

730-
let node = Rc::new(AssignExpression::new(id, location, left, right));
730+
let node = Rc::new(AssignStatement::new(id, location, left, right));
731731
self.arena.add_node(
732-
AstNode::Expression(Expression::Assign(node.clone(), None)),
732+
AstNode::Statement(Statement::Assign(node.clone())),
733733
parent_id,
734734
);
735735
node
@@ -1134,6 +1134,7 @@ impl<'a> Builder<'a, InitState> {
11341134
let location = Self::get_location(node, code);
11351135
let mut arguments = None;
11361136
let mut cursor = node.walk();
1137+
let mut returns = None;
11371138

11381139
let founded_arguments = node
11391140
.children_by_field_name("argument", &mut cursor)
@@ -1142,7 +1143,9 @@ impl<'a> Builder<'a, InitState> {
11421143
if !founded_arguments.is_empty() {
11431144
arguments = Some(founded_arguments);
11441145
}
1145-
let returns = self.build_type(id, &node.child_by_field_name("returns").unwrap(), code);
1146+
if let Some(returns_type_node) = node.child_by_field_name("returns") {
1147+
returns = Some(self.build_type(id, &returns_type_node, code));
1148+
}
11461149
let node = Rc::new(FunctionType::new(id, location, arguments, returns));
11471150
self.arena.add_node(
11481151
AstNode::Expression(Expression::Type(Type::Function(node.clone()), None)),

ast/src/symbols.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl SymbolTable {
6262
0,
6363
f.location.clone(),
6464
param_types,
65-
return_ty,
65+
Some(return_ty),
6666
)));
6767
table.map.insert(name.clone(), func_ty);
6868
}

ast/src/type_infer.rs

Lines changed: 144 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use anyhow::bail;
22

3-
use crate::types::{Definition, FunctionDefinition, Identifier, TypeInfo};
3+
use crate::types::{Definition, FunctionDefinition, Identifier, Statement, TypeInfo};
44
#[allow(clippy::all, unused_imports, dead_code)]
55
use crate::types::{
66
Expression, Literal, Location, OperatorKind, SimpleType, SourceFile, Type, TypeArray,
@@ -340,12 +340,6 @@ impl TypeChecker {
340340
#[allow(clippy::needless_pass_by_value)]
341341
fn infer_variables(&mut self, function_definition: Rc<FunctionDefinition>) {
342342
self.symbol_table.push_scope();
343-
// let mut generic_type_param_placeholders: HashMap<String, Option<String>> = HashMap::new();
344-
// if let Some(type_parameters) = &function_definition.type_parameters {
345-
// for tp in type_parameters {
346-
// generic_type_param_placeholders.insert(tp.name(), None);
347-
// }
348-
// }
349343
if let Some(arguments) = &function_definition.arguments {
350344
for argument in arguments {
351345
if let Err(err) = self
@@ -356,8 +350,151 @@ impl TypeChecker {
356350
}
357351
}
358352
}
353+
for stmt in &mut function_definition.body.statements() {
354+
self.infer_statement(
355+
stmt,
356+
function_definition.returns.clone(),
357+
&function_definition
358+
.type_parameters
359+
.as_ref()
360+
.unwrap_or(&vec![])
361+
.iter()
362+
.map(|p| p.name())
363+
.collect(),
364+
);
365+
}
359366
self.symbol_table.pop_scope();
360367
}
368+
369+
fn infer_statement(
370+
&mut self,
371+
statement: &mut Statement,
372+
return_type: Option<Type>,
373+
type_parameters: &Vec<String>,
374+
) {
375+
match statement {
376+
Statement::Block(block_type) => {
377+
self.symbol_table.push_scope();
378+
for stmt in &mut block_type.statements() {
379+
self.infer_statement(stmt, return_type.clone(), type_parameters);
380+
}
381+
self.symbol_table.pop_scope();
382+
}
383+
Statement::Expression(expression) => match expression {
384+
Expression::Assign(assign_expression, type_info) => {
385+
let target_type = self.infer_expression(&mut assign_expression.left);
386+
let value_type = self.infer_expression(&mut assign_expression.right);
387+
if let (Some(target_type), Some(value_type)) = (target_type, value_type) {
388+
if !Self::types_equal(&target_type, &value_type) {
389+
self.errors.push(format!(
390+
"Cannot assign value of type {:?} to variable of type {:?}",
391+
value_type, target_type
392+
));
393+
}
394+
}
395+
}
396+
},
397+
Statement::Return(return_statement) => todo!(),
398+
Statement::Loop(loop_statement) => todo!(),
399+
Statement::Break(break_statement) => todo!(),
400+
Statement::If(if_statement) => todo!(),
401+
Statement::VariableDefinition(variable_definition_statement) => {
402+
if let Some(mut initial_value) = variable_definition_statement.value.clone() {
403+
if let Some(init_type) = self.infer_expression(&mut initial_value) {
404+
if !Self::types_equal(&init_type, &variable_definition_statement.ty) {
405+
self.errors.push(format!(
406+
"Type mismatch in variable definition: expected {:?}, found {:?}",
407+
variable_definition_statement.ty, init_type
408+
));
409+
}
410+
}
411+
}
412+
if let Err(err) = self.symbol_table.push_variable_to_scope(
413+
variable_definition_statement.name(),
414+
variable_definition_statement.ty.clone(),
415+
) {
416+
self.errors.push(err.to_string());
417+
}
418+
//TODO handle the case when the variable is not initialized
419+
}
420+
Statement::TypeDefinition(type_definition_statement) => todo!(),
421+
Statement::Assert(assert_statement) => todo!(),
422+
Statement::ConstantDefinition(constant_definition) => todo!(),
423+
}
424+
}
425+
426+
fn infer_expression(&mut self, expression: &mut Expression) -> Option<Type> {
427+
match expression {
428+
Expression::ArrayIndexAccess(array_index_access_expression, type_info) => todo!(),
429+
Expression::MemberAccess(member_access_expression, type_info) => todo!(),
430+
Expression::FunctionCall(function_call_expression, type_info) => todo!(),
431+
Expression::PrefixUnary(prefix_unary_expression, type_info) => todo!(),
432+
Expression::Parenthesized(parenthesized_expression, type_info) => {
433+
self.infer_expression(expression)
434+
}
435+
Expression::Binary(binary_expression, type_info) => todo!(),
436+
Expression::Literal(literal, type_info) => todo!(),
437+
Expression::Identifier(identifier, type_info) => todo!(),
438+
Expression::Type(_, type_info) => todo!(),
439+
Expression::Uzumaki(uzumaki_expression, type_info) => todo!(),
440+
}
441+
}
442+
443+
fn types_equal(left: &Type, right: &Type) -> bool {
444+
match (left, right) {
445+
(Type::Array(left), Type::Array(right)) => {
446+
Self::types_equal(&left.element_type, &right.element_type)
447+
}
448+
(Type::Simple(left), Type::Simple(right)) => left.name == right.name,
449+
(Type::Generic(left), Type::Generic(right)) => {
450+
left.base.name() == right.base.name() && left.parameters == right.parameters
451+
}
452+
(Type::Qualified(left), Type::Qualified(right)) => left.name() == right.name(),
453+
(Type::QualifiedName(left), Type::QualifiedName(right)) => {
454+
left.qualifier() == right.qualifier() && left.name() == right.name()
455+
}
456+
(Type::Custom(left), Type::Custom(right)) => left.name() == right.name(),
457+
(Type::Function(left), Type::Function(right)) => {
458+
let left_has_return_type = left.returns.is_some();
459+
let right_has_return_type = right.returns.is_some();
460+
if left_has_return_type != right_has_return_type {
461+
return false;
462+
}
463+
if left_has_return_type {
464+
if let (Some(left_return_type), Some(right_return_type)) =
465+
(&left.returns, &right.returns)
466+
{
467+
if !Self::types_equal(left_return_type, right_return_type) {
468+
return false;
469+
}
470+
}
471+
}
472+
let left_has_parameters = left.parameters.is_some();
473+
let right_has_parameters = right.parameters.is_some();
474+
if left_has_parameters != right_has_parameters {
475+
return false;
476+
}
477+
if left_has_parameters {
478+
if let (Some(left_parameters), Some(right_parameters)) =
479+
(&left.parameters, &right.parameters)
480+
{
481+
if left_parameters.len() != right_parameters.len() {
482+
return false;
483+
}
484+
for (left_param, right_param) in
485+
left_parameters.iter().zip(right_parameters.iter())
486+
{
487+
if !Self::types_equal(left_param, right_param) {
488+
return false;
489+
}
490+
}
491+
}
492+
}
493+
return true;
494+
}
495+
_ => false,
496+
}
497+
}
361498
}
362499

363500
// pub struct TypeContext<'a> {
@@ -482,44 +619,6 @@ impl TypeChecker {
482619
// }
483620
// }
484621

485-
// pub fn traverse_source_files(
486-
// source_files: &[crate::types::SourceFile],
487-
// symbols: &SymbolTable,
488-
// ) -> Result<(), TypeError> {
489-
// let ctx = TypeContext { symbols };
490-
// for sf in source_files {
491-
// for def in &sf.definitions {
492-
// if let crate::types::Definition::Function(func_rc) = def {
493-
// traverse_function(func_rc, &ctx)?;
494-
// }
495-
// }
496-
// }
497-
// Ok(())
498-
// }
499-
500-
// fn traverse_function(
501-
// func_rc: &std::rc::Rc<crate::types::FunctionDefinition>,
502-
// ctx: &TypeContext,
503-
// ) -> Result<(), TypeError> {
504-
// let func = func_rc.as_ref();
505-
// // TODO: insert parameter types into ctx.symbols if needed
506-
// traverse_block(&func.body, ctx)
507-
// }
508-
509-
// fn traverse_block(
510-
// block_type: &crate::types::BlockType,
511-
// ctx: &TypeContext,
512-
// ) -> Result<(), TypeError> {
513-
// use crate::types::BlockType;
514-
// if let BlockType::Block(b_rc) = block_type {
515-
// let block = b_rc.as_ref();
516-
// for stmt in &block.statements {
517-
// traverse_statement(stmt, ctx)?;
518-
// }
519-
// }
520-
// Ok(())
521-
// }
522-
523622
// fn traverse_statement(stmt: &crate::types::Statement, ctx: &TypeContext) -> Result<(), TypeError> {
524623
// use crate::types::Statement;
525624
// match stmt {

ast/src/types.rs

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ pub struct TypeInfo {
4747

4848
impl TypeInfo {
4949
#[must_use]
50-
pub fn new(ty: Type) -> Self {
51-
match &ty {
50+
pub fn new(ty: &Type) -> Self {
51+
match ty {
5252
Type::Simple(simple) => Self {
5353
name: simple.name.clone(),
5454
type_params: vec![],
@@ -66,22 +66,17 @@ impl TypeInfo {
6666
type_params: vec![],
6767
},
6868
Type::Array(array) => Self {
69-
name: format!("Array<{}>", TypeInfo::new(array.element_type.clone()).name),
69+
name: format!("Array<{}>", TypeInfo::new(&array.element_type).name),
7070
type_params: vec![],
7171
},
7272
Type::Function(func) => {
7373
//REVISIT
7474
let param_types = func
7575
.parameters
7676
.as_ref()
77-
.map(|params| {
78-
params
79-
.iter()
80-
.map(|p| TypeInfo::new(p.clone()))
81-
.collect::<Vec<_>>()
82-
})
77+
.map(|params| params.iter().map(TypeInfo::new).collect::<Vec<_>>())
8378
.unwrap_or_default();
84-
let return_type = TypeInfo::new(func.returns.clone());
79+
let return_type = TypeInfo::new(&func.returns);
8580
Self {
8681
name: format!("Function<{}, {}>", param_types.len(), return_type.name),
8782
type_params: vec![],
@@ -273,6 +268,7 @@ ast_enums! {
273268
pub enum Statement {
274269
@inner_enum Block(BlockType),
275270
@inner_enum Expression(Expression),
271+
Assign(Rc<AssignStatement>),
276272
Return(Rc<ReturnStatement>),
277273
Loop(Rc<LoopStatement>),
278274
Break(Rc<BreakStatement>),
@@ -284,7 +280,6 @@ ast_enums! {
284280
}
285281

286282
pub enum Expression {
287-
Assign(Rc<AssignExpression>, Option<TypeInfo>),
288283
ArrayIndexAccess(Rc<ArrayIndexAccessExpression>, Option<TypeInfo>),
289284
MemberAccess(Rc<MemberAccessExpression>, Option<TypeInfo>),
290285
FunctionCall(Rc<FunctionCallExpression>, Option<TypeInfo>),
@@ -444,7 +439,7 @@ ast_nodes! {
444439

445440
pub struct VariableDefinitionStatement {
446441
pub name: Rc<Identifier>,
447-
pub type_: Type,
442+
pub ty: Type,
448443
pub value: Option<Expression>,
449444
pub is_uzumaki: bool,
450445
}
@@ -454,7 +449,7 @@ ast_nodes! {
454449
pub ty: Type,
455450
}
456451

457-
pub struct AssignExpression {
452+
pub struct AssignStatement {
458453
pub left: Expression,
459454
pub right: Expression,
460455
}
@@ -524,7 +519,7 @@ ast_nodes! {
524519

525520
pub struct FunctionType {
526521
pub parameters: Option<Vec<Type>>,
527-
pub returns: Type,
522+
pub returns: Option<Type>,
528523
}
529524

530525
pub struct QualifiedName {

0 commit comments

Comments
 (0)