Skip to content

Commit 6009c2c

Browse files
committed
Refactor TypeChecker and SymbolTable to enhance variable scope management; update TypeInfo structure and add function definition retrieval in SourceFile
1 parent b58c828 commit 6009c2c

File tree

3 files changed

+180
-51
lines changed

3 files changed

+180
-51
lines changed

ast/src/type_infer.rs

Lines changed: 111 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
use anyhow::bail;
22

3-
use crate::types::{Definition, Identifier, TypeInfo};
3+
use crate::types::{Definition, FunctionDefinition, Identifier, TypeInfo};
44
#[allow(clippy::all, unused_imports, dead_code)]
55
use crate::types::{
66
Expression, Literal, Location, OperatorKind, SimpleType, SourceFile, Type, TypeArray,
77
};
8-
use crate::{arena::Arena, types::GenericType};
98
use std::collections::HashMap;
109
use std::rc::Rc;
1110

@@ -24,7 +23,7 @@ struct SymbolTable {
2423
}
2524

2625
impl SymbolTable {
27-
pub fn new() -> Self {
26+
fn new() -> Self {
2827
let mut table = SymbolTable {
2928
types: HashMap::default(),
3029
functions: HashMap::default(),
@@ -63,7 +62,27 @@ impl SymbolTable {
6362
table
6463
}
6564

66-
fn register_type(&mut self, name: String, type_params: Vec<TypeInfo>) -> anyhow::Result<()> {
65+
fn push_scope(&mut self) {
66+
self.variables.push(HashMap::new());
67+
}
68+
69+
fn pop_scope(&mut self) {
70+
self.variables.pop();
71+
}
72+
73+
fn push_variable_to_scope(&mut self, name: String, var_type: Type) -> anyhow::Result<()> {
74+
if let Some(scope) = self.variables.last_mut() {
75+
if scope.contains_key(&name) {
76+
bail!("Variable `{name}` already declared in this scope");
77+
}
78+
scope.insert(name, var_type);
79+
Ok(())
80+
} else {
81+
bail!("No active scope to push variables".to_string())
82+
}
83+
}
84+
85+
fn register_type(&mut self, name: String, type_params: Vec<String>) -> anyhow::Result<()> {
6786
if self.types.contains_key(&name) {
6887
bail!("Type `{name}` is already defined")
6988
}
@@ -80,7 +99,7 @@ impl SymbolTable {
8099
return_type: Type,
81100
) -> Result<(), String> {
82101
if self.functions.contains_key(&name) {
83-
return Err(format!("Function `{}` is already defined", name));
102+
return Err(format!("Function `{name}` is already defined"));
84103
}
85104
self.functions.insert(
86105
name.clone(),
@@ -119,6 +138,14 @@ impl TypeChecker {
119138
if !self.errors.is_empty() {
120139
bail!(std::mem::take(&mut self.errors).join("; ")) //TODO: handle it better
121140
}
141+
for source_file in program {
142+
for function_definition in &mut source_file.function_definitions() {
143+
self.infer_variables(function_definition.clone());
144+
}
145+
}
146+
if !self.errors.is_empty() {
147+
bail!(std::mem::take(&mut self.errors).join("; ")) //TODO: handle it better
148+
}
122149
Ok(())
123150
}
124151

@@ -131,13 +158,7 @@ impl TypeChecker {
131158
let type_params = generic_type
132159
.parameters
133160
.iter()
134-
.map(|param| {
135-
if let Type::Generic(generic_param) = param {
136-
Self::construct_generic_type_info(generic_param.clone())
137-
} else {
138-
panic!("Expected a generic type parameter")
139-
}
140-
})
161+
.map(|param| param.name())
141162
.collect();
142163
self.symbol_table
143164
.register_type(type_definition.name(), type_params)
@@ -197,42 +218,55 @@ impl TypeChecker {
197218
}
198219
}
199220

200-
fn construct_generic_type_info(generic_type_definition: Rc<GenericType>) -> TypeInfo {
201-
Self::construct_type_info(generic_type_definition, vec![])
202-
}
203-
204-
#[allow(clippy::needless_pass_by_value)]
205-
fn construct_type_info(
206-
generic_type_definition: Rc<GenericType>,
207-
type_params: Vec<TypeInfo>,
208-
) -> TypeInfo {
209-
let name = generic_type_definition.base.name();
210-
let mut type_info = TypeInfo { name, type_params };
211-
for param in &generic_type_definition.parameters {
212-
if let Type::Generic(generic_param) = param {
213-
let param_info = Self::construct_generic_type_info(generic_param.clone());
214-
type_info.type_params.push(param_info);
215-
}
216-
}
217-
type_info
218-
}
219-
220-
//TODO continue implementing this function
221221
fn collect_function_and_constant_definitions(&mut self, program: &mut Vec<SourceFile>) {
222222
for sf in program {
223223
for definition in &sf.definitions {
224224
match definition {
225225
Definition::Constant(constant_definition) => todo!(),
226226
Definition::Function(function_definition) => {
227227
for param in function_definition.arguments.as_ref().unwrap_or(&vec![]) {
228-
self.validate_type(&param.ty, None);
228+
self.validate_type(
229+
&param.ty,
230+
function_definition.type_parameters.as_ref(),
231+
);
229232
}
230233
if let Some(return_type) = &function_definition.returns {
231234
self.validate_type(
232235
return_type,
233-
function_definition.type_parameters.clone(),
236+
function_definition.type_parameters.as_ref(),
234237
);
235238
}
239+
if !self.errors.is_empty() {
240+
continue;
241+
}
242+
if let Err(err) = self.symbol_table.register_function(
243+
function_definition.name(),
244+
function_definition
245+
.type_parameters
246+
.as_ref()
247+
.unwrap_or(&vec![])
248+
.iter()
249+
.map(|param| param.name())
250+
.collect(),
251+
function_definition
252+
.arguments
253+
.as_ref()
254+
.unwrap_or(&vec![])
255+
.iter()
256+
.map(|param| param.ty.clone())
257+
.collect(),
258+
function_definition
259+
.returns
260+
.as_ref()
261+
.unwrap_or(&Type::Simple(Rc::new(SimpleType::new(
262+
0,
263+
Location::default(),
264+
"Unit".into(),
265+
))))
266+
.clone(),
267+
) {
268+
self.errors.push(err);
269+
}
236270
}
237271
Definition::ExternalFunction(external_function_definition) => {
238272
todo!()
@@ -245,7 +279,7 @@ impl TypeChecker {
245279
}
246280
}
247281

248-
fn validate_type(&mut self, ty: &Type, type_parameters: Option<Vec<Rc<Type>>>) {
282+
fn validate_type(&mut self, ty: &Type, type_parameters: Option<&Vec<Rc<Identifier>>>) {
249283
match ty {
250284
Type::Array(type_array) => self.validate_type(&type_array.element_type, None),
251285
Type::Simple(simple_type) => {
@@ -264,17 +298,26 @@ impl TypeChecker {
264298
.push(format!("Unknown type `{}`", generic_type.base.name()));
265299
}
266300
if let Some(type_params) = &type_parameters {
301+
if type_params.len() != generic_type.parameters.len() {
302+
self.errors.push(format!(
303+
"Type parameter count mismatch for `{}`: expected {}, found {}",
304+
generic_type.base.name(),
305+
generic_type.parameters.len(),
306+
type_params.len()
307+
));
308+
}
309+
let generic_param_names: Vec<String> = generic_type
310+
.parameters
311+
.iter()
312+
.map(|param| param.name())
313+
.collect();
267314
for param in &generic_type.parameters {
268-
if let Type::Generic(generic_param) = param {
269-
if !type_params
270-
.iter()
271-
.any(|tp| tp.name == generic_param.base.name)
272-
{
273-
self.errors.push(format!(
274-
"Unknown type parameter `{}` in generic type `{}`",
275-
generic_param.base.name, generic_type.base.name
276-
));
277-
}
315+
if !generic_param_names.contains(&param.name()) {
316+
self.errors.push(format!(
317+
"Type parameter `{}` not found in `{}`",
318+
param.name(),
319+
generic_type.base.name()
320+
));
278321
}
279322
}
280323
}
@@ -293,6 +336,28 @@ impl TypeChecker {
293336
}
294337
}
295338
}
339+
340+
#[allow(clippy::needless_pass_by_value)]
341+
fn infer_variables(&mut self, function_definition: Rc<FunctionDefinition>) {
342+
self.symbol_table.push_scope();
343+
// let mut generic_type_param_placeholders: HashMap<String, Option<String>> = HashMap::new();
344+
// if let Some(type_parameters) = &function_definition.type_parameters {
345+
// for tp in type_parameters {
346+
// generic_type_param_placeholders.insert(tp.name(), None);
347+
// }
348+
// }
349+
if let Some(arguments) = &function_definition.arguments {
350+
for argument in arguments {
351+
if let Err(err) = self
352+
.symbol_table
353+
.push_variable_to_scope(argument.name(), argument.ty.clone())
354+
{
355+
self.errors.push(err.to_string());
356+
}
357+
}
358+
}
359+
self.symbol_table.pop_scope();
360+
}
296361
}
297362

298363
// pub struct TypeContext<'a> {

ast/src/types.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,60 @@ impl Location {
4141
#[derive(Debug, Eq, PartialEq, Clone)]
4242
pub struct TypeInfo {
4343
pub name: String,
44-
pub type_params: Vec<TypeInfo>,
44+
pub type_params: Vec<String>,
4545
// (Field type information could be added here if needed for struct field checking.)
4646
}
4747

48+
impl TypeInfo {
49+
#[must_use]
50+
pub fn new(ty: Type) -> Self {
51+
match &ty {
52+
Type::Simple(simple) => Self {
53+
name: simple.name.clone(),
54+
type_params: vec![],
55+
},
56+
Type::Generic(generic) => Self {
57+
name: generic.base.name.clone(),
58+
type_params: generic.parameters.iter().map(|p| p.name.clone()).collect(),
59+
},
60+
Type::QualifiedName(qualified_name) => Self {
61+
name: qualified_name.qualifier.name.clone(),
62+
type_params: vec![],
63+
},
64+
Type::Qualified(qualified) => Self {
65+
name: qualified.alias.name.clone(),
66+
type_params: vec![],
67+
},
68+
Type::Array(array) => Self {
69+
name: format!("Array<{}>", TypeInfo::new(array.element_type.clone()).name),
70+
type_params: vec![],
71+
},
72+
Type::Function(func) => {
73+
//REVISIT
74+
let param_types = func
75+
.parameters
76+
.as_ref()
77+
.map(|params| {
78+
params
79+
.iter()
80+
.map(|p| TypeInfo::new(p.clone()))
81+
.collect::<Vec<_>>()
82+
})
83+
.unwrap_or_default();
84+
let return_type = TypeInfo::new(func.returns.clone());
85+
Self {
86+
name: format!("Function<{}, {}>", param_types.len(), return_type.name),
87+
type_params: vec![],
88+
}
89+
}
90+
Type::Custom(custom) => Self {
91+
name: custom.name.clone(),
92+
type_params: vec![],
93+
},
94+
}
95+
}
96+
}
97+
4898
impl Display for Location {
4999
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
50100
write!(

ast/src/types_impl.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ impl SourceFile {
2424
}
2525
}
2626

27+
impl SourceFile {
28+
#[must_use]
29+
pub fn function_definitions(&self) -> Vec<Rc<FunctionDefinition>> {
30+
self.definitions
31+
.iter()
32+
.filter_map(|def| match def {
33+
Definition::Function(func) => Some(func.clone()),
34+
_ => None,
35+
})
36+
.collect()
37+
}
38+
}
39+
2740
impl UseDirective {
2841
#[must_use]
2942
pub fn new(
@@ -129,6 +142,7 @@ impl Identifier {
129142
Identifier { id, location, name }
130143
}
131144

145+
#[must_use]
132146
pub fn name(&self) -> String {
133147
self.name.clone()
134148
}
@@ -176,8 +190,8 @@ impl FunctionDefinition {
176190
}
177191

178192
#[must_use]
179-
pub fn name(&self) -> &str {
180-
&self.name.name
193+
pub fn name(&self) -> String {
194+
self.name.name.clone()
181195
}
182196

183197
#[must_use]
@@ -247,8 +261,8 @@ impl Parameter {
247261
}
248262

249263
#[must_use]
250-
pub fn name(&self) -> &str {
251-
&self.name.name
264+
pub fn name(&self) -> String {
265+
self.name.name.clone()
252266
}
253267
}
254268

0 commit comments

Comments
 (0)