11use anyhow:: bail;
22
33use crate :: types:: { Definition , FunctionDefinition , Identifier , Statement , TypeInfo } ;
4- #[ allow( clippy:: all, unused_imports, dead_code) ]
5- use crate :: types:: {
6- Expression , Literal , Location , OperatorKind , SimpleType , SourceFile , Type , TypeArray ,
7- } ;
4+ use crate :: types:: { Expression , Location , SimpleType , SourceFile , Type } ;
85use std:: collections:: HashMap ;
96use std:: rc:: Rc ;
107
@@ -373,35 +370,50 @@ impl TypeChecker {
373370 type_parameters : & Vec < String > ,
374371 ) {
375372 match statement {
373+ Statement :: Assign ( assign_statement) => {
374+ let target_type = self . infer_expression ( & mut assign_statement. left . borrow_mut ( ) ) ;
375+ if let Expression :: Uzumaki ( _, ref mut type_info) =
376+ & mut * assign_statement. right . borrow_mut ( )
377+ {
378+ * type_info = target_type;
379+ } else {
380+ let value_type =
381+ self . infer_expression ( & mut assign_statement. right . borrow_mut ( ) ) ;
382+ if let ( Some ( target_type) , Some ( value_type) ) = ( target_type, value_type) {
383+ if target_type != value_type {
384+ self . errors . push ( format ! (
385+ "Cannot assign value of type {value_type:?} to variable of type {target_type:?}"
386+ ) ) ;
387+ }
388+ }
389+ }
390+ }
376391 Statement :: Block ( block_type) => {
377392 self . symbol_table . push_scope ( ) ;
378393 for stmt in & mut block_type. statements ( ) {
379394 self . infer_statement ( stmt, return_type. clone ( ) , type_parameters) ;
380395 }
381396 self . symbol_table . pop_scope ( ) ;
382397 }
383- Statement :: Expression ( expression) => match expression {
384- Expression :: Assign ( assign_expression, type_info) => {
385- let target_type = self . infer_expression ( & mut assign_expression. left ) ;
386- let value_type = self . infer_expression ( & mut assign_expression. right ) ;
387- if let ( Some ( target_type) , Some ( value_type) ) = ( target_type, value_type) {
388- if !Self :: types_equal ( & target_type, & value_type) {
389- self . errors . push ( format ! (
390- "Cannot assign value of type {:?} to variable of type {:?}" ,
391- value_type, target_type
392- ) ) ;
393- }
394- }
395- }
396- } ,
398+ Statement :: Expression ( expression) => {
399+ self . infer_expression ( expression) ;
400+ }
397401 Statement :: Return ( return_statement) => todo ! ( ) ,
398402 Statement :: Loop ( loop_statement) => todo ! ( ) ,
399403 Statement :: Break ( break_statement) => todo ! ( ) ,
400404 Statement :: If ( if_statement) => todo ! ( ) ,
401405 Statement :: VariableDefinition ( variable_definition_statement) => {
402- if let Some ( mut initial_value) = variable_definition_statement. value . clone ( ) {
403- if let Some ( init_type) = self . infer_expression ( & mut initial_value) {
404- if !Self :: types_equal ( & init_type, & variable_definition_statement. ty ) {
406+ let target_type = TypeInfo :: new ( & variable_definition_statement. ty ) ;
407+ if let Some ( initial_value) = variable_definition_statement. value . clone ( ) {
408+ if let Expression :: Uzumaki ( uzumaki, ref mut type_info) =
409+ & mut * initial_value. borrow_mut ( )
410+ {
411+ println ! ( "Uzumaki: {uzumaki:?}\n " ) ;
412+ * type_info = Some ( target_type) ;
413+ } else if let Some ( init_type) =
414+ self . infer_expression ( & mut initial_value. borrow_mut ( ) )
415+ {
416+ if init_type != TypeInfo :: new ( & variable_definition_statement. ty ) {
405417 self . errors . push ( format ! (
406418 "Type mismatch in variable definition: expected {:?}, found {:?}" ,
407419 variable_definition_statement. ty, init_type
@@ -423,23 +435,26 @@ impl TypeChecker {
423435 }
424436 }
425437
426- fn infer_expression ( & mut self , expression : & mut Expression ) -> Option < Type > {
438+ fn infer_expression ( & mut self , expression : & mut Expression ) -> Option < TypeInfo > {
427439 match expression {
428- Expression :: ArrayIndexAccess ( array_index_access_expression, type_info) => todo ! ( ) ,
429- Expression :: MemberAccess ( member_access_expression, type_info) => todo ! ( ) ,
430- Expression :: FunctionCall ( function_call_expression, type_info) => todo ! ( ) ,
431- Expression :: PrefixUnary ( prefix_unary_expression, type_info) => todo ! ( ) ,
432- Expression :: Parenthesized ( parenthesized_expression, type_info) => {
440+ Expression :: ArrayIndexAccess ( array_index_access_expression, ref mut type_info) => {
441+ todo ! ( )
442+ }
443+ Expression :: MemberAccess ( member_access_expression, ref mut type_info) => todo ! ( ) ,
444+ Expression :: FunctionCall ( function_call_expression, ref mut type_info) => todo ! ( ) ,
445+ Expression :: PrefixUnary ( prefix_unary_expression, ref mut type_info) => todo ! ( ) ,
446+ Expression :: Parenthesized ( parenthesized_expression, ref mut type_info) => {
433447 self . infer_expression ( expression)
434448 }
435- Expression :: Binary ( binary_expression, type_info) => todo ! ( ) ,
436- Expression :: Literal ( literal, type_info) => todo ! ( ) ,
437- Expression :: Identifier ( identifier, type_info) => todo ! ( ) ,
438- Expression :: Type ( _, type_info) => todo ! ( ) ,
439- Expression :: Uzumaki ( uzumaki_expression , type_info) => todo ! ( ) ,
449+ Expression :: Binary ( binary_expression, ref mut type_info) => todo ! ( ) ,
450+ Expression :: Literal ( literal, ref mut type_info) => todo ! ( ) ,
451+ Expression :: Identifier ( identifier, ref mut type_info) => todo ! ( ) ,
452+ Expression :: Type ( _, ref mut type_info) => todo ! ( ) ,
453+ Expression :: Uzumaki ( _ , ref mut type_info) => type_info . clone ( ) ,
440454 }
441455 }
442456
457+ #[ allow( dead_code) ]
443458 fn types_equal ( left : & Type , right : & Type ) -> bool {
444459 match ( left, right) {
445460 ( Type :: Array ( left) , Type :: Array ( right) ) => {
@@ -490,17 +505,13 @@ impl TypeChecker {
490505 }
491506 }
492507 }
493- return true ;
508+ true
494509 }
495510 _ => false ,
496511 }
497512 }
498513}
499514
500- // pub struct TypeContext<'a> {
501- // pub symbols: &'a SymbolTable,
502- // }
503-
504515// /// Errors during type inference
505516// #[derive(Debug)]
506517// pub enum TypeError {
@@ -512,150 +523,3 @@ impl TypeChecker {
512523// UnknownIdentifier(String, Location),
513524// Other(String, Location),
514525// }
515-
516- // pub fn infer_expr(expr: &Expression, ctx: &TypeContext) -> Result<Type, TypeError> {
517- // match expr {
518- // Expression::Literal(lit, _) => match lit {
519- // Literal::Bool(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
520- // 0,
521- // Location::default(),
522- // "Bool".into(),
523- // )))),
524- // Literal::String(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
525- // 0,
526- // Location::default(),
527- // "String".into(),
528- // )))),
529- // Literal::Number(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
530- // 0,
531- // Location::default(),
532- // "Number".into(),
533- // )))),
534- // Literal::Unit(_) => Ok(Type::Simple(Rc::new(SimpleType::new(
535- // 0,
536- // Location::default(),
537- // "Unit".into(),
538- // )))),
539- // Literal::Array(arr) => {
540- // let mut elem_ty: Option<Type> = None;
541- // let arr_node = arr.as_ref();
542- // for e in &arr_node.elements {
543- // let ty = infer_expr(e, ctx)?;
544- // if let Some(prev) = &elem_ty {
545- // if *prev != ty {
546- // return Err(TypeError::Mismatch {
547- // expected: prev.clone(),
548- // found: ty.clone(),
549- // loc: arr.location.clone(),
550- // });
551- // }
552- // } else {
553- // elem_ty = Some(ty.clone());
554- // }
555- // }
556- // let element = elem_ty.unwrap_or_else(|| {
557- // Type::Simple(Rc::new(SimpleType::new(
558- // 0,
559- // Location::default(),
560- // "Unit".into(),
561- // )))
562- // });
563- // Ok(Type::Array(Rc::new(TypeArray::new(
564- // 0,
565- // Location::default(),
566- // element,
567- // None,
568- // ))))
569- // }
570- // },
571- // Expression::Identifier(id, _) => {
572- // let name = &id.name;
573- // if let Some(ty) = ctx.symbols.lookup(name) {
574- // Ok(ty)
575- // } else {
576- // Err(TypeError::UnknownIdentifier(
577- // name.clone(),
578- // id.location.clone(),
579- // ))
580- // }
581- // }
582- // Expression::Binary(bin, _) => {
583- // let left_ty = infer_expr(&bin.left, ctx)?;
584- // let right_ty = infer_expr(&bin.right, ctx)?;
585- // if left_ty != right_ty {
586- // return Err(TypeError::Mismatch {
587- // expected: left_ty.clone(),
588- // found: right_ty.clone(),
589- // loc: bin.location.clone(),
590- // });
591- // }
592- // let res_ty = match &bin.operator {
593- // OperatorKind::Add | OperatorKind::Sub | OperatorKind::Mul | OperatorKind::Div => {
594- // left_ty.clone()
595- // }
596- // OperatorKind::Eq
597- // | OperatorKind::Ne
598- // | OperatorKind::Lt
599- // | OperatorKind::Le
600- // | OperatorKind::Gt
601- // | OperatorKind::Ge => Type::Simple(Rc::new(SimpleType::new(
602- // 0,
603- // bin.location.clone(),
604- // "Bool".into(),
605- // ))),
606- // op => {
607- // return Err(TypeError::Other(
608- // format!("Operator {op:?} not supported"),
609- // bin.location.clone(),
610- // ))
611- // }
612- // };
613- // Ok(res_ty)
614- // }
615- // _ => Err(TypeError::Other(
616- // "Type inference not implemented for this expression variant".into(),
617- // Location::default(),
618- // )),
619- // }
620- // }
621-
622- // fn traverse_statement(stmt: &crate::types::Statement, ctx: &TypeContext) -> Result<(), TypeError> {
623- // use crate::types::Statement;
624- // match stmt {
625- // Statement::Expression(expr) => {
626- // infer_expr(expr, ctx)?;
627- // }
628- // Statement::Return(ret_rc) => {
629- // infer_expr(&ret_rc.expression, ctx)?;
630- // }
631- // Statement::Assert(assert_rc) => {
632- // infer_expr(&assert_rc.expression, ctx)?;
633- // }
634- // Statement::If(if_rc) => {
635- // infer_expr(&if_rc.condition, ctx)?;
636- // traverse_block(&if_rc.if_arm, ctx)?;
637- // if let Some(else_arm) = &if_rc.else_arm {
638- // traverse_block(else_arm, ctx)?;
639- // }
640- // }
641- // Statement::Loop(loop_rc) => {
642- // if let Some(cond) = &loop_rc.condition {
643- // infer_expr(cond, ctx)?;
644- // }
645- // traverse_block(&loop_rc.body, ctx)?;
646- // }
647- // Statement::VariableDefinition(vd_rc) => {
648- // if let Some(init) = &vd_rc.value {
649- // infer_expr(init, ctx)?;
650- // }
651- // }
652- // Statement::ConstantDefinition(_cd_rc) => {
653- // // constant definitions have a Literal value; skip or handle separately
654- // }
655- // Statement::Block(block_type) => {
656- // traverse_block(block_type, ctx)?;
657- // }
658- // _ => {}
659- // }
660- // Ok(())
661- // }
0 commit comments