11use anyhow:: bail;
22
3- use crate :: types:: { Definition , FunctionDefinition , Identifier , TypeInfo } ;
3+ use crate :: types:: { Definition , FunctionDefinition , Identifier , Statement , TypeInfo } ;
44#[ allow( clippy:: all, unused_imports, dead_code) ]
55use crate :: types:: {
66 Expression , Literal , Location , OperatorKind , SimpleType , SourceFile , Type , TypeArray ,
@@ -340,12 +340,6 @@ impl TypeChecker {
340340 #[ allow( clippy:: needless_pass_by_value) ]
341341 fn infer_variables ( & mut self , function_definition : Rc < FunctionDefinition > ) {
342342 self . symbol_table . push_scope ( ) ;
343- // let mut generic_type_param_placeholders: HashMap<String, Option<String>> = HashMap::new();
344- // if let Some(type_parameters) = &function_definition.type_parameters {
345- // for tp in type_parameters {
346- // generic_type_param_placeholders.insert(tp.name(), None);
347- // }
348- // }
349343 if let Some ( arguments) = & function_definition. arguments {
350344 for argument in arguments {
351345 if let Err ( err) = self
@@ -356,8 +350,151 @@ impl TypeChecker {
356350 }
357351 }
358352 }
353+ for stmt in & mut function_definition. body . statements ( ) {
354+ self . infer_statement (
355+ stmt,
356+ function_definition. returns . clone ( ) ,
357+ & function_definition
358+ . type_parameters
359+ . as_ref ( )
360+ . unwrap_or ( & vec ! [ ] )
361+ . iter ( )
362+ . map ( |p| p. name ( ) )
363+ . collect ( ) ,
364+ ) ;
365+ }
359366 self . symbol_table . pop_scope ( ) ;
360367 }
368+
369+ fn infer_statement (
370+ & mut self ,
371+ statement : & mut Statement ,
372+ return_type : Option < Type > ,
373+ type_parameters : & Vec < String > ,
374+ ) {
375+ match statement {
376+ Statement :: Block ( block_type) => {
377+ self . symbol_table . push_scope ( ) ;
378+ for stmt in & mut block_type. statements ( ) {
379+ self . infer_statement ( stmt, return_type. clone ( ) , type_parameters) ;
380+ }
381+ self . symbol_table . pop_scope ( ) ;
382+ }
383+ Statement :: Expression ( expression) => match expression {
384+ Expression :: Assign ( assign_expression, type_info) => {
385+ let target_type = self . infer_expression ( & mut assign_expression. left ) ;
386+ let value_type = self . infer_expression ( & mut assign_expression. right ) ;
387+ if let ( Some ( target_type) , Some ( value_type) ) = ( target_type, value_type) {
388+ if !Self :: types_equal ( & target_type, & value_type) {
389+ self . errors . push ( format ! (
390+ "Cannot assign value of type {:?} to variable of type {:?}" ,
391+ value_type, target_type
392+ ) ) ;
393+ }
394+ }
395+ }
396+ } ,
397+ Statement :: Return ( return_statement) => todo ! ( ) ,
398+ Statement :: Loop ( loop_statement) => todo ! ( ) ,
399+ Statement :: Break ( break_statement) => todo ! ( ) ,
400+ Statement :: If ( if_statement) => todo ! ( ) ,
401+ Statement :: VariableDefinition ( variable_definition_statement) => {
402+ if let Some ( mut initial_value) = variable_definition_statement. value . clone ( ) {
403+ if let Some ( init_type) = self . infer_expression ( & mut initial_value) {
404+ if !Self :: types_equal ( & init_type, & variable_definition_statement. ty ) {
405+ self . errors . push ( format ! (
406+ "Type mismatch in variable definition: expected {:?}, found {:?}" ,
407+ variable_definition_statement. ty, init_type
408+ ) ) ;
409+ }
410+ }
411+ }
412+ if let Err ( err) = self . symbol_table . push_variable_to_scope (
413+ variable_definition_statement. name ( ) ,
414+ variable_definition_statement. ty . clone ( ) ,
415+ ) {
416+ self . errors . push ( err. to_string ( ) ) ;
417+ }
418+ //TODO handle the case when the variable is not initialized
419+ }
420+ Statement :: TypeDefinition ( type_definition_statement) => todo ! ( ) ,
421+ Statement :: Assert ( assert_statement) => todo ! ( ) ,
422+ Statement :: ConstantDefinition ( constant_definition) => todo ! ( ) ,
423+ }
424+ }
425+
426+ fn infer_expression ( & mut self , expression : & mut Expression ) -> Option < Type > {
427+ match expression {
428+ Expression :: ArrayIndexAccess ( array_index_access_expression, type_info) => todo ! ( ) ,
429+ Expression :: MemberAccess ( member_access_expression, type_info) => todo ! ( ) ,
430+ Expression :: FunctionCall ( function_call_expression, type_info) => todo ! ( ) ,
431+ Expression :: PrefixUnary ( prefix_unary_expression, type_info) => todo ! ( ) ,
432+ Expression :: Parenthesized ( parenthesized_expression, type_info) => {
433+ self . infer_expression ( expression)
434+ }
435+ Expression :: Binary ( binary_expression, type_info) => todo ! ( ) ,
436+ Expression :: Literal ( literal, type_info) => todo ! ( ) ,
437+ Expression :: Identifier ( identifier, type_info) => todo ! ( ) ,
438+ Expression :: Type ( _, type_info) => todo ! ( ) ,
439+ Expression :: Uzumaki ( uzumaki_expression, type_info) => todo ! ( ) ,
440+ }
441+ }
442+
443+ fn types_equal ( left : & Type , right : & Type ) -> bool {
444+ match ( left, right) {
445+ ( Type :: Array ( left) , Type :: Array ( right) ) => {
446+ Self :: types_equal ( & left. element_type , & right. element_type )
447+ }
448+ ( Type :: Simple ( left) , Type :: Simple ( right) ) => left. name == right. name ,
449+ ( Type :: Generic ( left) , Type :: Generic ( right) ) => {
450+ left. base . name ( ) == right. base . name ( ) && left. parameters == right. parameters
451+ }
452+ ( Type :: Qualified ( left) , Type :: Qualified ( right) ) => left. name ( ) == right. name ( ) ,
453+ ( Type :: QualifiedName ( left) , Type :: QualifiedName ( right) ) => {
454+ left. qualifier ( ) == right. qualifier ( ) && left. name ( ) == right. name ( )
455+ }
456+ ( Type :: Custom ( left) , Type :: Custom ( right) ) => left. name ( ) == right. name ( ) ,
457+ ( Type :: Function ( left) , Type :: Function ( right) ) => {
458+ let left_has_return_type = left. returns . is_some ( ) ;
459+ let right_has_return_type = right. returns . is_some ( ) ;
460+ if left_has_return_type != right_has_return_type {
461+ return false ;
462+ }
463+ if left_has_return_type {
464+ if let ( Some ( left_return_type) , Some ( right_return_type) ) =
465+ ( & left. returns , & right. returns )
466+ {
467+ if !Self :: types_equal ( left_return_type, right_return_type) {
468+ return false ;
469+ }
470+ }
471+ }
472+ let left_has_parameters = left. parameters . is_some ( ) ;
473+ let right_has_parameters = right. parameters . is_some ( ) ;
474+ if left_has_parameters != right_has_parameters {
475+ return false ;
476+ }
477+ if left_has_parameters {
478+ if let ( Some ( left_parameters) , Some ( right_parameters) ) =
479+ ( & left. parameters , & right. parameters )
480+ {
481+ if left_parameters. len ( ) != right_parameters. len ( ) {
482+ return false ;
483+ }
484+ for ( left_param, right_param) in
485+ left_parameters. iter ( ) . zip ( right_parameters. iter ( ) )
486+ {
487+ if !Self :: types_equal ( left_param, right_param) {
488+ return false ;
489+ }
490+ }
491+ }
492+ }
493+ return true ;
494+ }
495+ _ => false ,
496+ }
497+ }
361498}
362499
363500// pub struct TypeContext<'a> {
@@ -482,44 +619,6 @@ impl TypeChecker {
482619// }
483620// }
484621
485- // pub fn traverse_source_files(
486- // source_files: &[crate::types::SourceFile],
487- // symbols: &SymbolTable,
488- // ) -> Result<(), TypeError> {
489- // let ctx = TypeContext { symbols };
490- // for sf in source_files {
491- // for def in &sf.definitions {
492- // if let crate::types::Definition::Function(func_rc) = def {
493- // traverse_function(func_rc, &ctx)?;
494- // }
495- // }
496- // }
497- // Ok(())
498- // }
499-
500- // fn traverse_function(
501- // func_rc: &std::rc::Rc<crate::types::FunctionDefinition>,
502- // ctx: &TypeContext,
503- // ) -> Result<(), TypeError> {
504- // let func = func_rc.as_ref();
505- // // TODO: insert parameter types into ctx.symbols if needed
506- // traverse_block(&func.body, ctx)
507- // }
508-
509- // fn traverse_block(
510- // block_type: &crate::types::BlockType,
511- // ctx: &TypeContext,
512- // ) -> Result<(), TypeError> {
513- // use crate::types::BlockType;
514- // if let BlockType::Block(b_rc) = block_type {
515- // let block = b_rc.as_ref();
516- // for stmt in &block.statements {
517- // traverse_statement(stmt, ctx)?;
518- // }
519- // }
520- // Ok(())
521- // }
522-
523622// fn traverse_statement(stmt: &crate::types::Statement, ctx: &TypeContext) -> Result<(), TypeError> {
524623// use crate::types::Statement;
525624// match stmt {
0 commit comments