Skip to content

Commit 9584053

Browse files
committed
Refactor Type Inference and Expression Handling: Update AssignStatement and VariableDefinitionStatement to use RefCell for expressions; enhance type inference logic for Uzumaki expressions; adjust WatEmitter to handle new expression structure.
1 parent 68760e1 commit 9584053

File tree

5 files changed

+75
-195
lines changed

5 files changed

+75
-195
lines changed

ast/src/type_infer.rs

Lines changed: 48 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
use anyhow::bail;
22

33
use crate::types::{Definition, FunctionDefinition, Identifier, Statement, TypeInfo};
4-
#[allow(clippy::all, unused_imports, dead_code)]
5-
use crate::types::{
6-
Expression, Literal, Location, OperatorKind, SimpleType, SourceFile, Type, TypeArray,
7-
};
4+
use crate::types::{Expression, Location, SimpleType, SourceFile, Type};
85
use std::collections::HashMap;
96
use std::rc::Rc;
107

@@ -373,35 +370,50 @@ impl TypeChecker {
373370
type_parameters: &Vec<String>,
374371
) {
375372
match statement {
373+
Statement::Assign(assign_statement) => {
374+
let target_type = self.infer_expression(&mut assign_statement.left.borrow_mut());
375+
if let Expression::Uzumaki(_, ref mut type_info) =
376+
&mut *assign_statement.right.borrow_mut()
377+
{
378+
*type_info = target_type;
379+
} else {
380+
let value_type =
381+
self.infer_expression(&mut assign_statement.right.borrow_mut());
382+
if let (Some(target_type), Some(value_type)) = (target_type, value_type) {
383+
if target_type != value_type {
384+
self.errors.push(format!(
385+
"Cannot assign value of type {value_type:?} to variable of type {target_type:?}"
386+
));
387+
}
388+
}
389+
}
390+
}
376391
Statement::Block(block_type) => {
377392
self.symbol_table.push_scope();
378393
for stmt in &mut block_type.statements() {
379394
self.infer_statement(stmt, return_type.clone(), type_parameters);
380395
}
381396
self.symbol_table.pop_scope();
382397
}
383-
Statement::Expression(expression) => match expression {
384-
Expression::Assign(assign_expression, type_info) => {
385-
let target_type = self.infer_expression(&mut assign_expression.left);
386-
let value_type = self.infer_expression(&mut assign_expression.right);
387-
if let (Some(target_type), Some(value_type)) = (target_type, value_type) {
388-
if !Self::types_equal(&target_type, &value_type) {
389-
self.errors.push(format!(
390-
"Cannot assign value of type {:?} to variable of type {:?}",
391-
value_type, target_type
392-
));
393-
}
394-
}
395-
}
396-
},
398+
Statement::Expression(expression) => {
399+
self.infer_expression(expression);
400+
}
397401
Statement::Return(return_statement) => todo!(),
398402
Statement::Loop(loop_statement) => todo!(),
399403
Statement::Break(break_statement) => todo!(),
400404
Statement::If(if_statement) => todo!(),
401405
Statement::VariableDefinition(variable_definition_statement) => {
402-
if let Some(mut initial_value) = variable_definition_statement.value.clone() {
403-
if let Some(init_type) = self.infer_expression(&mut initial_value) {
404-
if !Self::types_equal(&init_type, &variable_definition_statement.ty) {
406+
let target_type = TypeInfo::new(&variable_definition_statement.ty);
407+
if let Some(initial_value) = variable_definition_statement.value.clone() {
408+
if let Expression::Uzumaki(uzumaki, ref mut type_info) =
409+
&mut *initial_value.borrow_mut()
410+
{
411+
println!("Uzumaki: {uzumaki:?}\n");
412+
*type_info = Some(target_type);
413+
} else if let Some(init_type) =
414+
self.infer_expression(&mut initial_value.borrow_mut())
415+
{
416+
if init_type != TypeInfo::new(&variable_definition_statement.ty) {
405417
self.errors.push(format!(
406418
"Type mismatch in variable definition: expected {:?}, found {:?}",
407419
variable_definition_statement.ty, init_type
@@ -423,23 +435,26 @@ impl TypeChecker {
423435
}
424436
}
425437

426-
fn infer_expression(&mut self, expression: &mut Expression) -> Option<Type> {
438+
fn infer_expression(&mut self, expression: &mut Expression) -> Option<TypeInfo> {
427439
match expression {
428-
Expression::ArrayIndexAccess(array_index_access_expression, type_info) => todo!(),
429-
Expression::MemberAccess(member_access_expression, type_info) => todo!(),
430-
Expression::FunctionCall(function_call_expression, type_info) => todo!(),
431-
Expression::PrefixUnary(prefix_unary_expression, type_info) => todo!(),
432-
Expression::Parenthesized(parenthesized_expression, type_info) => {
440+
Expression::ArrayIndexAccess(array_index_access_expression, ref mut type_info) => {
441+
todo!()
442+
}
443+
Expression::MemberAccess(member_access_expression, ref mut type_info) => todo!(),
444+
Expression::FunctionCall(function_call_expression, ref mut type_info) => todo!(),
445+
Expression::PrefixUnary(prefix_unary_expression, ref mut type_info) => todo!(),
446+
Expression::Parenthesized(parenthesized_expression, ref mut type_info) => {
433447
self.infer_expression(expression)
434448
}
435-
Expression::Binary(binary_expression, type_info) => todo!(),
436-
Expression::Literal(literal, type_info) => todo!(),
437-
Expression::Identifier(identifier, type_info) => todo!(),
438-
Expression::Type(_, type_info) => todo!(),
439-
Expression::Uzumaki(uzumaki_expression, type_info) => todo!(),
449+
Expression::Binary(binary_expression, ref mut type_info) => todo!(),
450+
Expression::Literal(literal, ref mut type_info) => todo!(),
451+
Expression::Identifier(identifier, ref mut type_info) => todo!(),
452+
Expression::Type(_, ref mut type_info) => todo!(),
453+
Expression::Uzumaki(_, ref mut type_info) => type_info.clone(),
440454
}
441455
}
442456

457+
#[allow(dead_code)]
443458
fn types_equal(left: &Type, right: &Type) -> bool {
444459
match (left, right) {
445460
(Type::Array(left), Type::Array(right)) => {
@@ -490,17 +505,13 @@ impl TypeChecker {
490505
}
491506
}
492507
}
493-
return true;
508+
true
494509
}
495510
_ => false,
496511
}
497512
}
498513
}
499514

500-
// pub struct TypeContext<'a> {
501-
// pub symbols: &'a SymbolTable,
502-
// }
503-
504515
// /// Errors during type inference
505516
// #[derive(Debug)]
506517
// pub enum TypeError {
@@ -512,150 +523,3 @@ impl TypeChecker {
512523
// UnknownIdentifier(String, Location),
513524
// Other(String, Location),
514525
// }
515-
516-
// pub fn infer_expr(expr: &Expression, ctx: &TypeContext) -> Result<Type, TypeError> {
517-
// match expr {
518-
// Expression::Literal(lit, _) => match lit {
519-
// Literal::Bool(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
520-
// 0,
521-
// Location::default(),
522-
// "Bool".into(),
523-
// )))),
524-
// Literal::String(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
525-
// 0,
526-
// Location::default(),
527-
// "String".into(),
528-
// )))),
529-
// Literal::Number(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
530-
// 0,
531-
// Location::default(),
532-
// "Number".into(),
533-
// )))),
534-
// Literal::Unit(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
535-
// 0,
536-
// Location::default(),
537-
// "Unit".into(),
538-
// )))),
539-
// Literal::Array(arr) => {
540-
// let mut elem_ty: Option<Type> = None;
541-
// let arr_node = arr.as_ref();
542-
// for e in &arr_node.elements {
543-
// let ty = infer_expr(e, ctx)?;
544-
// if let Some(prev) = &elem_ty {
545-
// if *prev != ty {
546-
// return Err(TypeError::Mismatch {
547-
// expected: prev.clone(),
548-
// found: ty.clone(),
549-
// loc: arr.location.clone(),
550-
// });
551-
// }
552-
// } else {
553-
// elem_ty = Some(ty.clone());
554-
// }
555-
// }
556-
// let element = elem_ty.unwrap_or_else(|| {
557-
// Type::Simple(Rc::new(SimpleType::new(
558-
// 0,
559-
// Location::default(),
560-
// "Unit".into(),
561-
// )))
562-
// });
563-
// Ok(Type::Array(Rc::new(TypeArray::new(
564-
// 0,
565-
// Location::default(),
566-
// element,
567-
// None,
568-
// ))))
569-
// }
570-
// },
571-
// Expression::Identifier(id, _) => {
572-
// let name = &id.name;
573-
// if let Some(ty) = ctx.symbols.lookup(name) {
574-
// Ok(ty)
575-
// } else {
576-
// Err(TypeError::UnknownIdentifier(
577-
// name.clone(),
578-
// id.location.clone(),
579-
// ))
580-
// }
581-
// }
582-
// Expression::Binary(bin, _) => {
583-
// let left_ty = infer_expr(&bin.left, ctx)?;
584-
// let right_ty = infer_expr(&bin.right, ctx)?;
585-
// if left_ty != right_ty {
586-
// return Err(TypeError::Mismatch {
587-
// expected: left_ty.clone(),
588-
// found: right_ty.clone(),
589-
// loc: bin.location.clone(),
590-
// });
591-
// }
592-
// let res_ty = match &bin.operator {
593-
// OperatorKind::Add | OperatorKind::Sub | OperatorKind::Mul | OperatorKind::Div => {
594-
// left_ty.clone()
595-
// }
596-
// OperatorKind::Eq
597-
// | OperatorKind::Ne
598-
// | OperatorKind::Lt
599-
// | OperatorKind::Le
600-
// | OperatorKind::Gt
601-
// | OperatorKind::Ge => Type::Simple(Rc::new(SimpleType::new(
602-
// 0,
603-
// bin.location.clone(),
604-
// "Bool".into(),
605-
// ))),
606-
// op => {
607-
// return Err(TypeError::Other(
608-
// format!("Operator {op:?} not supported"),
609-
// bin.location.clone(),
610-
// ))
611-
// }
612-
// };
613-
// Ok(res_ty)
614-
// }
615-
// _ => Err(TypeError::Other(
616-
// "Type inference not implemented for this expression variant".into(),
617-
// Location::default(),
618-
// )),
619-
// }
620-
// }
621-
622-
// fn traverse_statement(stmt: &crate::types::Statement, ctx: &TypeContext) -> Result<(), TypeError> {
623-
// use crate::types::Statement;
624-
// match stmt {
625-
// Statement::Expression(expr) => {
626-
// infer_expr(expr, ctx)?;
627-
// }
628-
// Statement::Return(ret_rc) => {
629-
// infer_expr(&ret_rc.expression, ctx)?;
630-
// }
631-
// Statement::Assert(assert_rc) => {
632-
// infer_expr(&assert_rc.expression, ctx)?;
633-
// }
634-
// Statement::If(if_rc) => {
635-
// infer_expr(&if_rc.condition, ctx)?;
636-
// traverse_block(&if_rc.if_arm, ctx)?;
637-
// if let Some(else_arm) = &if_rc.else_arm {
638-
// traverse_block(else_arm, ctx)?;
639-
// }
640-
// }
641-
// Statement::Loop(loop_rc) => {
642-
// if let Some(cond) = &loop_rc.condition {
643-
// infer_expr(cond, ctx)?;
644-
// }
645-
// traverse_block(&loop_rc.body, ctx)?;
646-
// }
647-
// Statement::VariableDefinition(vd_rc) => {
648-
// if let Some(init) = &vd_rc.value {
649-
// infer_expr(init, ctx)?;
650-
// }
651-
// }
652-
// Statement::ConstantDefinition(_cd_rc) => {
653-
// // constant definitions have a Literal value; skip or handle separately
654-
// }
655-
// Statement::Block(block_type) => {
656-
// traverse_block(block_type, ctx)?;
657-
// }
658-
// _ => {}
659-
// }
660-
// Ok(())
661-
// }

ast/src/types.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use core::fmt;
22
use std::{
3+
cell::RefCell,
34
fmt::{Display, Formatter},
45
rc::Rc,
56
};
@@ -76,7 +77,14 @@ impl TypeInfo {
7677
.as_ref()
7778
.map(|params| params.iter().map(TypeInfo::new).collect::<Vec<_>>())
7879
.unwrap_or_default();
79-
let return_type = TypeInfo::new(&func.returns);
80+
let return_type = if func.returns.is_some() {
81+
Self::new(func.returns.as_ref().unwrap())
82+
} else {
83+
Self {
84+
name: "Unit".to_string(),
85+
type_params: vec![],
86+
}
87+
};
8088
Self {
8189
name: format!("Function<{}, {}>", param_types.len(), return_type.name),
8290
type_params: vec![],
@@ -440,7 +448,7 @@ ast_nodes! {
440448
pub struct VariableDefinitionStatement {
441449
pub name: Rc<Identifier>,
442450
pub ty: Type,
443-
pub value: Option<Expression>,
451+
pub value: Option<RefCell<Expression>>,
444452
pub is_uzumaki: bool,
445453
}
446454

@@ -450,8 +458,8 @@ ast_nodes! {
450458
}
451459

452460
pub struct AssignStatement {
453-
pub left: Expression,
454-
pub right: Expression,
461+
pub left: RefCell<Expression>,
462+
pub right: RefCell<Expression>,
455463
}
456464

457465
pub struct ArrayIndexAccessExpression {

ast/src/types_impl.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::rc::Rc;
1+
use std::{
2+
cell::{Ref, RefCell},
3+
rc::Rc,
4+
};
25

36
use super::types::{
47
ArrayIndexAccessExpression, ArrayLiteral, AssertStatement, AssignStatement, BinaryExpression,
@@ -370,7 +373,7 @@ impl VariableDefinitionStatement {
370373
location,
371374
name,
372375
ty: type_,
373-
value,
376+
value: value.map(RefCell::new),
374377
is_uzumaki,
375378
}
376379
}
@@ -399,8 +402,8 @@ impl AssignStatement {
399402
AssignStatement {
400403
id,
401404
location,
402-
left,
403-
right,
405+
left: RefCell::new(left),
406+
right: RefCell::new(right),
404407
}
405408
}
406409
}
@@ -619,10 +622,12 @@ impl QualifiedName {
619622
}
620623
}
621624

625+
#[must_use]
622626
pub fn name(&self) -> String {
623627
self.name.name()
624628
}
625629

630+
#[must_use]
626631
pub fn qualifier(&self) -> String {
627632
self.qualifier.name()
628633
}
@@ -639,10 +644,12 @@ impl TypeQualifiedName {
639644
}
640645
}
641646

647+
#[must_use]
642648
pub fn name(&self) -> String {
643649
self.name.name()
644650
}
645651

652+
#[must_use]
646653
pub fn alias(&self) -> String {
647654
self.alias.name()
648655
}

tests/src/ast/expression.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ mod expression_tests {
4747
let mut uzumaki_nodes = uzumaki_nodes.iter().collect::<Vec<_>>();
4848
uzumaki_nodes.sort_by_key(|node| node.start_line());
4949
for (i, node) in uzumaki_nodes.iter().enumerate() {
50-
if let AstNode::Expression(Expression::Uzumaki(_, ty)) = node {
50+
if let AstNode::Expression(Expression::Uzumaki(uzumaki, ty)) = node {
51+
println!("Uzumaki: {uzumaki:?}\n");
5152
assert!(
5253
ty.as_ref().unwrap().name == expected_types[i],
5354
"Expected type {} for UzumakiExpression, found {:?}",

0 commit comments

Comments
 (0)