From 600b79f6c127599a04e5ec0f4382ed15708ae673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Fri, 23 Jan 2026 13:45:17 -0500 Subject: [PATCH 1/3] Begin defining a public API for types --- .../src/backend/ir/ast/node_extensions/mod.rs | 3 + .../backend/ir/ast/node_extensions/types.rs | 294 ++++++++++++++++++ .../cargo/crate/src/backend/types/mod.rs | 1 + 3 files changed, 298 insertions(+) create mode 100644 crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/mod.rs b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/mod.rs index f9b4c6bacd..3d7965bafd 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/mod.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/mod.rs @@ -12,6 +12,9 @@ pub use identifiers::{ Identifier, IdentifierStruct, Reference, YulIdentifier, YulIdentifierStruct, }; +mod types; +pub use types::Type; + impl SourceUnitStruct { pub fn file_id(&self) -> String { self.semantic diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs new file mode 100644 index 0000000000..62a875c828 --- /dev/null +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs @@ -0,0 +1,294 @@ +use paste::paste; +use std::rc::Rc; + +use crate::backend::{ + types::{self, DataLocation, TypeId}, + SemanticAnalysis, +}; + +use super::Definition; + +// __SLANG_TYPE_TYPES__ keep in sync with binder types +#[derive(Clone)] +pub enum Type { + Address(AddressType), + Array(ArrayType), + Boolean(BooleanType), + ByteArray(ByteArrayType), + Bytes(BytesType), + Contract(ContractType), + Enum(EnumType), + FixedPointNumber(FixedPointNumberType), + Function(FunctionType), + Integer(IntegerType), + Interface(InterfaceType), + Literal(LiteralType), + Mapping(MappingType), + String(StringType), + Struct(StructType), + Tuple(TupleType), + UserDefinedValue(UserDefinedValueType), + Void(VoidType), +} + +macro_rules! define_type_variant { + ($type:ident) => { + paste! { + #[derive(Clone)] + pub struct [<$type Type>] { + type_id: TypeId, + semantic: Rc, + } + + impl [<$type Type>] { + #[allow(unused)] + fn internal_type(&self) -> &types::Type { + self.semantic.types.get_type_by_id(self.type_id) + } + } + } + }; +} + +define_type_variant!(Address); +define_type_variant!(Array); +define_type_variant!(Boolean); +define_type_variant!(ByteArray); +define_type_variant!(Bytes); +define_type_variant!(Contract); +define_type_variant!(Enum); +define_type_variant!(FixedPointNumber); +define_type_variant!(Function); +define_type_variant!(Integer); +define_type_variant!(Interface); +define_type_variant!(Literal); +define_type_variant!(Mapping); +define_type_variant!(String); +define_type_variant!(Struct); +define_type_variant!(Tuple); +define_type_variant!(UserDefinedValue); +define_type_variant!(Void); + +impl Type { + pub fn create(type_id: TypeId, semantic: &Rc) -> Self { + let type_ = semantic.types().get_type_by_id(type_id); + let semantic = Rc::clone(semantic); + match type_ { + types::Type::Address { .. } => Self::Address(AddressType { type_id, semantic }), + types::Type::Array { .. } => Self::Array(ArrayType { type_id, semantic }), + types::Type::Boolean => Self::Boolean(BooleanType { type_id, semantic }), + types::Type::ByteArray { .. } => Self::ByteArray(ByteArrayType { type_id, semantic }), + types::Type::Bytes { .. } => Self::Bytes(BytesType { type_id, semantic }), + types::Type::Contract { .. } => Self::Contract(ContractType { type_id, semantic }), + types::Type::Enum { .. } => Self::Enum(EnumType { type_id, semantic }), + types::Type::FixedPointNumber { .. } => { + Self::FixedPointNumber(FixedPointNumberType { type_id, semantic }) + } + types::Type::Function(_) => Self::Function(FunctionType { type_id, semantic }), + types::Type::Integer { .. } => Self::Integer(IntegerType { type_id, semantic }), + types::Type::Interface { .. } => Self::Interface(InterfaceType { type_id, semantic }), + types::Type::Literal(_) => Self::Literal(LiteralType { type_id, semantic }), + types::Type::Mapping { .. } => Self::Mapping(MappingType { type_id, semantic }), + types::Type::String { .. } => Self::String(StringType { type_id, semantic }), + types::Type::Struct { .. } => Self::Struct(StructType { type_id, semantic }), + types::Type::Tuple { .. } => Self::Tuple(TupleType { type_id, semantic }), + types::Type::UserDefinedValue { .. } => { + Self::UserDefinedValue(UserDefinedValueType { type_id, semantic }) + } + types::Type::Void => Self::Void(VoidType { type_id, semantic }), + } + } + + pub fn type_id(&self) -> TypeId { + match self { + Type::Address(details) => details.type_id, + Type::Array(details) => details.type_id, + Type::Boolean(details) => details.type_id, + Type::ByteArray(details) => details.type_id, + Type::Bytes(details) => details.type_id, + Type::Contract(details) => details.type_id, + Type::Enum(details) => details.type_id, + Type::FixedPointNumber(details) => details.type_id, + Type::Function(details) => details.type_id, + Type::Integer(details) => details.type_id, + Type::Interface(details) => details.type_id, + Type::Literal(details) => details.type_id, + Type::Mapping(details) => details.type_id, + Type::String(details) => details.type_id, + Type::Struct(details) => details.type_id, + Type::Tuple(details) => details.type_id, + Type::UserDefinedValue(details) => details.type_id, + Type::Void(details) => details.type_id, + } + } +} + +impl AddressType { + pub fn payable(&self) -> bool { + let types::Type::Address { payable } = self.internal_type() else { + unreachable!("invalid address type"); + }; + *payable + } +} + +impl ArrayType { + pub fn element_type(&self) -> Type { + let types::Type::Array { element_type, .. } = self.internal_type() else { + unreachable!("invalid array type"); + }; + Type::create(*element_type, &self.semantic) + } + pub fn location(&self) -> DataLocation { + let types::Type::Array { location, .. } = self.internal_type() else { + unreachable!("invalid array type"); + }; + *location + } +} + +impl BooleanType {} + +impl ByteArrayType { + pub fn width(&self) -> u32 { + let types::Type::ByteArray { width } = self.internal_type() else { + unreachable!("invalid byte array type"); + }; + *width + } +} + +impl BytesType { + pub fn location(&self) -> DataLocation { + let types::Type::Bytes { location } = self.internal_type() else { + unreachable!("invalid bytes type"); + }; + *location + } +} + +impl ContractType { + pub fn definition(&self) -> Definition { + let types::Type::Contract { definition_id } = self.internal_type() else { + unreachable!("invalid contract type"); + }; + Definition::create(*definition_id, &self.semantic) + } +} + +impl EnumType { + pub fn definition(&self) -> Definition { + let types::Type::Enum { definition_id } = self.internal_type() else { + unreachable!("invalid enum type"); + }; + Definition::create(*definition_id, &self.semantic) + } +} + +impl FixedPointNumberType { + pub fn signed(&self) -> bool { + let types::Type::FixedPointNumber { signed, .. } = self.internal_type() else { + unreachable!("invalid fixed point number type"); + }; + *signed + } + pub fn bits(&self) -> u32 { + let types::Type::FixedPointNumber { bits, .. } = self.internal_type() else { + unreachable!("invalid fixed point number type"); + }; + *bits + } + pub fn precision_bits(&self) -> u32 { + let types::Type::FixedPointNumber { precision_bits, .. } = self.internal_type() else { + unreachable!("invalid fixed point number type"); + }; + *precision_bits + } +} + +impl FunctionType {} + +impl IntegerType { + pub fn signed(&self) -> bool { + let types::Type::Integer { signed, .. } = self.internal_type() else { + unreachable!("invalid integer type"); + }; + *signed + } + pub fn bits(&self) -> u32 { + let types::Type::Integer { bits, .. } = self.internal_type() else { + unreachable!("invalid integer type"); + }; + *bits + } +} + +impl InterfaceType { + pub fn definition(&self) -> Definition { + let types::Type::Interface { definition_id } = self.internal_type() else { + unreachable!("invalid interface type"); + }; + Definition::create(*definition_id, &self.semantic) + } +} + +impl LiteralType {} + +impl MappingType { + pub fn key_type(&self) -> Type { + let types::Type::Mapping { key_type_id, .. } = self.internal_type() else { + unreachable!("invalid mapping type"); + }; + Type::create(*key_type_id, &self.semantic) + } + pub fn value_type(&self) -> Type { + let types::Type::Mapping { value_type_id, .. } = self.internal_type() else { + unreachable!("invalid mapping type"); + }; + Type::create(*value_type_id, &self.semantic) + } +} + +impl StringType { + pub fn location(&self) -> DataLocation { + let types::Type::String { location } = self.internal_type() else { + unreachable!("invalid string type"); + }; + *location + } +} + +impl StructType { + pub fn definition(&self) -> Definition { + let types::Type::Struct { definition_id, .. } = self.internal_type() else { + unreachable!("invalid struct type"); + }; + Definition::create(*definition_id, &self.semantic) + } + pub fn location(&self) -> DataLocation { + let types::Type::Struct { location, .. } = self.internal_type() else { + unreachable!("invalid struct type"); + }; + *location + } +} + +impl TupleType { + pub fn types(&self) -> Vec { + let types::Type::Tuple { types } = self.internal_type() else { + unreachable!("invalid tuple type"); + }; + types.iter().map(|type_id| Type::create(*type_id, &self.semantic)).collect() + } +} + +impl UserDefinedValueType { + pub fn definition(&self) -> Definition { + let types::Type::UserDefinedValue { definition_id } = self.internal_type() else { + unreachable!("invalid user defined value type"); + }; + Definition::create(*definition_id, &self.semantic) + } +} + +impl VoidType {} diff --git a/crates/solidity/outputs/cargo/crate/src/backend/types/mod.rs b/crates/solidity/outputs/cargo/crate/src/backend/types/mod.rs index 559e97016b..585b3e1902 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/types/mod.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/types/mod.rs @@ -8,6 +8,7 @@ pub use registry::TypeRegistry; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub struct TypeId(usize); +// __SLANG_TYPE_TYPES__ keep in sync with AST types #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum Type { Address { From fa9d5cb87304e6454ca271694d306a7c0d5c79a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Fri, 23 Jan 2026 14:52:25 -0500 Subject: [PATCH 2/3] Expose types from the AST and add unit test --- .../ir/ast/node_extensions/identifiers.rs | 6 +- .../backend/ir/ast/node_extensions/types.rs | 13 +- .../src/backend/ir/ast/nodes.generated.rs | 390 +++++++++++++++++- .../crate/src/backend/ir/ast/nodes.rs.jinja2 | 5 + .../cargo/crate/src/backend/semantic/mod.rs | 11 +- .../cargo/tests/src/backend/semantic/ast.rs | 33 ++ 6 files changed, 449 insertions(+), 9 deletions(-) diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/identifiers.rs b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/identifiers.rs index 796f84393c..f33065dc41 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/identifiers.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/identifiers.rs @@ -1,7 +1,7 @@ use std::rc::Rc; use super::super::IdentifierPathStruct; -use super::Definition; +use super::{Definition, Type}; use crate::backend::SemanticAnalysis; use crate::cst::{NodeId, TerminalKind, TerminalNode}; @@ -56,6 +56,10 @@ impl IdentifierStruct { pub fn references(&self) -> Vec { self.semantic.references_binding_to(self.ir_node.id()) } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.id()) + } } pub type YulIdentifierStruct = IdentifierStruct; diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs index 62a875c828..7e8735aad1 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs @@ -1,12 +1,10 @@ -use paste::paste; use std::rc::Rc; -use crate::backend::{ - types::{self, DataLocation, TypeId}, - SemanticAnalysis, -}; +use paste::paste; use super::Definition; +use crate::backend::types::{self, DataLocation, TypeId}; +use crate::backend::SemanticAnalysis; // __SLANG_TYPE_TYPES__ keep in sync with binder types #[derive(Clone)] @@ -278,7 +276,10 @@ impl TupleType { let types::Type::Tuple { types } = self.internal_type() else { unreachable!("invalid tuple type"); }; - types.iter().map(|type_id| Type::create(*type_id, &self.semantic)).collect() + types + .iter() + .map(|type_id| Type::create(*type_id, &self.semantic)) + .collect() } } diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.generated.rs b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.generated.rs index b3066dbee0..457375bc79 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.generated.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.generated.rs @@ -5,11 +5,11 @@ use std::rc::Rc; use paste::paste; -use super::input as input_ir; use super::node_extensions::{ create_identifier, create_yul_identifier, Identifier, IdentifierStruct, YulIdentifier, YulIdentifierStruct, }; +use super::{input as input_ir, Type}; use crate::backend::{binder, SemanticAnalysis}; use crate::cst::{NodeId, TerminalKind, TerminalNode, TextIndex}; @@ -44,6 +44,10 @@ impl SourceUnitStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type PragmaDirective = Rc; @@ -73,6 +77,10 @@ impl PragmaDirectiveStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type AbicoderPragma = Rc; @@ -102,6 +110,10 @@ impl AbicoderPragmaStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ExperimentalPragma = Rc; @@ -131,6 +143,10 @@ impl ExperimentalPragmaStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type VersionPragma = Rc; @@ -160,6 +176,10 @@ impl VersionPragmaStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type VersionRange = Rc; @@ -193,6 +213,10 @@ impl VersionRangeStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type VersionTerm = Rc; @@ -229,6 +253,10 @@ impl VersionTermStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type PathImport = Rc; @@ -265,6 +293,10 @@ impl PathImportStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ImportDeconstruction = Rc; @@ -298,6 +330,10 @@ impl ImportDeconstructionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ImportDeconstructionSymbol = Rc; @@ -334,6 +370,10 @@ impl ImportDeconstructionSymbolStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type UsingDirective = Rc; @@ -371,6 +411,10 @@ impl UsingDirectiveStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type UsingDeconstruction = Rc; @@ -400,6 +444,10 @@ impl UsingDeconstructionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type UsingDeconstructionSymbol = Rc; @@ -436,6 +484,10 @@ impl UsingDeconstructionSymbolStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ContractDefinition = Rc; @@ -484,6 +536,10 @@ impl ContractDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type InheritanceType = Rc; @@ -520,6 +576,10 @@ impl InheritanceTypeStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type InterfaceDefinition = Rc; @@ -560,6 +620,10 @@ impl InterfaceDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type LibraryDefinition = Rc; @@ -593,6 +657,10 @@ impl LibraryDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type StructDefinition = Rc; @@ -626,6 +694,10 @@ impl StructDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type StructMember = Rc; @@ -659,6 +731,10 @@ impl StructMemberStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type EnumDefinition = Rc; @@ -692,6 +768,10 @@ impl EnumDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ConstantDefinition = Rc; @@ -739,6 +819,10 @@ impl ConstantDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type StateVariableDefinition = Rc; @@ -794,6 +878,10 @@ impl StateVariableDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type FunctionDefinition = Rc; @@ -871,6 +959,10 @@ impl FunctionDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type Parameter = Rc; @@ -918,6 +1010,10 @@ impl ParameterStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type OverrideSpecifier = Rc; @@ -950,6 +1046,10 @@ impl OverrideSpecifierStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ModifierInvocation = Rc; @@ -986,6 +1086,10 @@ impl ModifierInvocationStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type EventDefinition = Rc; @@ -1023,6 +1127,10 @@ impl EventDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type UserDefinedValueTypeDefinition = Rc; @@ -1056,6 +1164,10 @@ impl UserDefinedValueTypeDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ErrorDefinition = Rc; @@ -1089,6 +1201,10 @@ impl ErrorDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ArrayTypeName = Rc; @@ -1125,6 +1241,10 @@ impl ArrayTypeNameStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type FunctionType = Rc; @@ -1169,6 +1289,10 @@ impl FunctionTypeStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type MappingType = Rc; @@ -1202,6 +1326,10 @@ impl MappingTypeStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type AddressType = Rc; @@ -1231,6 +1359,10 @@ impl AddressTypeStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type Block = Rc; @@ -1257,6 +1389,10 @@ impl BlockStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type UncheckedBlock = Rc; @@ -1286,6 +1422,10 @@ impl UncheckedBlockStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ExpressionStatement = Rc; @@ -1315,6 +1455,10 @@ impl ExpressionStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type AssemblyStatement = Rc; @@ -1352,6 +1496,10 @@ impl AssemblyStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type TupleDeconstructionStatement = Rc; @@ -1385,6 +1533,10 @@ impl TupleDeconstructionStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type VariableDeclarationStatement = Rc; @@ -1435,6 +1587,10 @@ impl VariableDeclarationStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type IfStatement = Rc; @@ -1475,6 +1631,10 @@ impl IfStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ForStatement = Rc; @@ -1519,6 +1679,10 @@ impl ForStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type WhileStatement = Rc; @@ -1552,6 +1716,10 @@ impl WhileStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type DoWhileStatement = Rc; @@ -1585,6 +1753,10 @@ impl DoWhileStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ContinueStatement = Rc; @@ -1610,6 +1782,10 @@ impl ContinueStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type BreakStatement = Rc; @@ -1635,6 +1811,10 @@ impl BreakStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ReturnStatement = Rc; @@ -1667,6 +1847,10 @@ impl ReturnStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type EmitStatement = Rc; @@ -1700,6 +1884,10 @@ impl EmitStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type TryStatement = Rc; @@ -1744,6 +1932,10 @@ impl TryStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type CatchClause = Rc; @@ -1780,6 +1972,10 @@ impl CatchClauseStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type CatchClauseError = Rc; @@ -1816,6 +2012,10 @@ impl CatchClauseErrorStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type RevertStatement = Rc; @@ -1849,6 +2049,10 @@ impl RevertStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ThrowStatement = Rc; @@ -1874,6 +2078,10 @@ impl ThrowStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type AssignmentExpression = Rc; @@ -1911,6 +2119,10 @@ impl AssignmentExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ConditionalExpression = Rc; @@ -1948,6 +2160,10 @@ impl ConditionalExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type OrExpression = Rc; @@ -1981,6 +2197,10 @@ impl OrExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type AndExpression = Rc; @@ -2014,6 +2234,10 @@ impl AndExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type EqualityExpression = Rc; @@ -2051,6 +2275,10 @@ impl EqualityExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type InequalityExpression = Rc; @@ -2088,6 +2316,10 @@ impl InequalityExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type BitwiseOrExpression = Rc; @@ -2121,6 +2353,10 @@ impl BitwiseOrExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type BitwiseXorExpression = Rc; @@ -2154,6 +2390,10 @@ impl BitwiseXorExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type BitwiseAndExpression = Rc; @@ -2187,6 +2427,10 @@ impl BitwiseAndExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ShiftExpression = Rc; @@ -2224,6 +2468,10 @@ impl ShiftExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type AdditiveExpression = Rc; @@ -2261,6 +2509,10 @@ impl AdditiveExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type MultiplicativeExpression = Rc; @@ -2298,6 +2550,10 @@ impl MultiplicativeExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ExponentiationExpression = Rc; @@ -2335,6 +2591,10 @@ impl ExponentiationExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type PostfixExpression = Rc; @@ -2368,6 +2628,10 @@ impl PostfixExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type PrefixExpression = Rc; @@ -2401,6 +2665,10 @@ impl PrefixExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type FunctionCallExpression = Rc; @@ -2434,6 +2702,10 @@ impl FunctionCallExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type CallOptionsExpression = Rc; @@ -2467,6 +2739,10 @@ impl CallOptionsExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type MemberAccessExpression = Rc; @@ -2500,6 +2776,10 @@ impl MemberAccessExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type IndexAccessExpression = Rc; @@ -2543,6 +2823,10 @@ impl IndexAccessExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type NamedArgument = Rc; @@ -2576,6 +2860,10 @@ impl NamedArgumentStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type TypeExpression = Rc; @@ -2605,6 +2893,10 @@ impl TypeExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type NewExpression = Rc; @@ -2634,6 +2926,10 @@ impl NewExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type TupleExpression = Rc; @@ -2663,6 +2959,10 @@ impl TupleExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type TupleValue = Rc; @@ -2695,6 +2995,10 @@ impl TupleValueStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type ArrayExpression = Rc; @@ -2724,6 +3028,10 @@ impl ArrayExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type HexNumberExpression = Rc; @@ -2760,6 +3068,10 @@ impl HexNumberExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type DecimalNumberExpression = Rc; @@ -2796,6 +3108,10 @@ impl DecimalNumberExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulBlock = Rc; @@ -2825,6 +3141,10 @@ impl YulBlockStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulFunctionDefinition = Rc; @@ -2869,6 +3189,10 @@ impl YulFunctionDefinitionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulVariableDeclarationStatement = Rc; @@ -2905,6 +3229,10 @@ impl YulVariableDeclarationStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulVariableDeclarationValue = Rc; @@ -2938,6 +3266,10 @@ impl YulVariableDeclarationValueStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulVariableAssignmentStatement = Rc; @@ -2975,6 +3307,10 @@ impl YulVariableAssignmentStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulColonAndEqual = Rc; @@ -3000,6 +3336,10 @@ impl YulColonAndEqualStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulStackAssignmentStatement = Rc; @@ -3033,6 +3373,10 @@ impl YulStackAssignmentStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulEqualAndColon = Rc; @@ -3058,6 +3402,10 @@ impl YulEqualAndColonStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulIfStatement = Rc; @@ -3091,6 +3439,10 @@ impl YulIfStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulForStatement = Rc; @@ -3132,6 +3484,10 @@ impl YulForStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulSwitchStatement = Rc; @@ -3165,6 +3521,10 @@ impl YulSwitchStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulDefaultCase = Rc; @@ -3194,6 +3554,10 @@ impl YulDefaultCaseStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulValueCase = Rc; @@ -3227,6 +3591,10 @@ impl YulValueCaseStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulLeaveStatement = Rc; @@ -3252,6 +3620,10 @@ impl YulLeaveStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulBreakStatement = Rc; @@ -3277,6 +3649,10 @@ impl YulBreakStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulContinueStatement = Rc; @@ -3302,6 +3678,10 @@ impl YulContinueStatementStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulLabel = Rc; @@ -3331,6 +3711,10 @@ impl YulLabelStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } pub type YulFunctionCallExpression = Rc; @@ -3364,6 +3748,10 @@ impl YulFunctionCallExpressionStruct { .get_text_offset_by_node_id(self.ir_node.node_id) .unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } // diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.rs.jinja2 b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.rs.jinja2 index b0055a9791..bbc3464b9d 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.rs.jinja2 +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/nodes.rs.jinja2 @@ -11,6 +11,7 @@ use crate::backend::{binder, SemanticAnalysis}; use crate::cst::{NodeId, TerminalKind, TerminalNode, TextIndex}; use super::node_extensions::{Identifier, IdentifierStruct, YulIdentifier, YulIdentifierStruct}; use super::node_extensions::{create_identifier, create_yul_identifier}; +use super::Type; // // Sequences: @@ -96,6 +97,10 @@ use super::node_extensions::{create_identifier, create_yul_identifier}; pub fn text_offset(&self) -> TextIndex { self.semantic.get_text_offset_by_node_id(self.ir_node.node_id).unwrap() } + + pub fn get_type(&self) -> Option { + self.semantic.get_type_from_node_id(self.ir_node.node_id) + } } {% endfor %} diff --git a/crates/solidity/outputs/cargo/crate/src/backend/semantic/mod.rs b/crates/solidity/outputs/cargo/crate/src/backend/semantic/mod.rs index 667797178a..d194f02ea7 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/semantic/mod.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/semantic/mod.rs @@ -3,7 +3,9 @@ use std::rc::Rc; use semver::Version; -use self::ast::{create_contract_definition, create_source_unit, ContractDefinition, Definition}; +use self::ast::{ + create_contract_definition, create_source_unit, ContractDefinition, Definition, Type, +}; use crate::backend::binder::Binder; pub use crate::backend::ir::{ast, ir2_flat_contracts as output_ir}; use crate::backend::types::TypeRegistry; @@ -195,4 +197,11 @@ impl SemanticAnalysis { pub(crate) fn get_text_offset_by_node_id(&self, node_id: NodeId) -> Option { self.text_offsets.get(&node_id).copied() } + + pub fn get_type_from_node_id(self: &Rc, node_id: NodeId) -> Option { + self.binder + .node_typing(node_id) + .as_type_id() + .map(|type_id| Type::create(type_id, self)) + } } diff --git a/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs b/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs index a89a8844f1..838a766617 100644 --- a/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs +++ b/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs @@ -89,3 +89,36 @@ fn test_text_offsets() -> Result<()> { Ok(()) } + +#[test] +fn test_get_type() -> Result<()> { + let unit = build_compilation_unit()?; + let semantic = unit.semantic_analysis(); + + let ownable = semantic + .find_contract_by_name("Ownable") + .expect("contract is found"); + + let state_variables = ownable + .members() + .iter() + .filter_map(|member| { + if let ast::ContractMember::StateVariableDefinition(definition) = member { + Some(definition) + } else { + None + } + }) + .collect::>(); + + assert_eq!(state_variables.len(), 1); + let owner = &state_variables[0]; + assert_eq!(owner.name().unparse(), "_owner"); + + let owner_type = owner + .get_type() + .expect("_owner state variable has resolved type"); + assert!(matches!(owner_type, ast::Type::Address(_))); + + Ok(()) +} From 2b2792f45797e41d1ee01748ec012f4302cf4d54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Mon, 26 Jan 2026 13:41:38 -0500 Subject: [PATCH 3/3] Add attribute getters for `Type::Function` --- .../backend/ir/ast/node_extensions/types.rs | 54 ++++++++++++++++++- .../cargo/tests/src/backend/semantic/ast.rs | 41 ++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs index 7e8735aad1..9d04570cb4 100644 --- a/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs +++ b/crates/solidity/outputs/cargo/crate/src/backend/ir/ast/node_extensions/types.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use paste::paste; use super::Definition; -use crate::backend::types::{self, DataLocation, TypeId}; +use crate::backend::types::{self, DataLocation, FunctionTypeKind, TypeId}; use crate::backend::SemanticAnalysis; // __SLANG_TYPE_TYPES__ keep in sync with binder types @@ -204,7 +204,57 @@ impl FixedPointNumberType { } } -impl FunctionType {} +impl FunctionType { + pub fn associated_definition(&self) -> Option { + let types::Type::Function(function_type) = self.internal_type() else { + unreachable!("invalid function type"); + }; + function_type + .definition_id + .map(|definition_id| Definition::create(definition_id, &self.semantic)) + } + + pub fn implicit_receiver_type(&self) -> Option { + let types::Type::Function(function_type) = self.internal_type() else { + unreachable!("invalid function type"); + }; + function_type + .implicit_receiver_type + .map(|type_id| Type::create(type_id, &self.semantic)) + } + + pub fn parameter_types(&self) -> Vec { + let types::Type::Function(function_type) = self.internal_type() else { + unreachable!("invalid function type"); + }; + function_type + .parameter_types + .iter() + .map(|type_id| Type::create(*type_id, &self.semantic)) + .collect() + } + + pub fn return_type(&self) -> Type { + let types::Type::Function(function_type) = self.internal_type() else { + unreachable!("invalid function type"); + }; + Type::create(function_type.return_type, &self.semantic) + } + + pub fn external(&self) -> bool { + let types::Type::Function(function_type) = self.internal_type() else { + unreachable!("invalid function type"); + }; + function_type.external + } + + pub fn kind(&self) -> FunctionTypeKind { + let types::Type::Function(function_type) = self.internal_type() else { + unreachable!("invalid function type"); + }; + function_type.kind + } +} impl IntegerType { pub fn signed(&self) -> bool { diff --git a/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs b/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs index 838a766617..2ba1e11af7 100644 --- a/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs +++ b/crates/solidity/outputs/cargo/tests/src/backend/semantic/ast.rs @@ -122,3 +122,44 @@ fn test_get_type() -> Result<()> { Ok(()) } + +#[test] +fn test_function_get_type() -> Result<()> { + let unit = build_compilation_unit()?; + let semantic = unit.semantic_analysis(); + + let counter = semantic + .find_contract_by_name("Counter") + .expect("contract is found"); + + let increment = counter + .members() + .iter() + .find_map(|member| { + if let ast::ContractMember::FunctionDefinition(function_definition) = member { + if function_definition + .name() + .is_some_and(|name| name.unparse() == "increment") + { + Some(function_definition) + } else { + None + } + } else { + None + } + }) + .expect("increment method is found"); + + let increment_type = increment.get_type().expect("increment method has a type"); + let ast::Type::Function(function_type) = increment_type else { + panic!("method's type is expect to be a function"); + }; + assert!(function_type.external()); + assert!(matches!(function_type.return_type(), ast::Type::Integer(_))); + assert!(function_type + .associated_definition() + .is_some_and(|definition| matches!(definition, ast::Definition::Function(_)))); + + Ok(()) +}