diff --git a/Cargo.lock b/Cargo.lock index 4dc8abc86..1475b65e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5970,6 +5970,7 @@ version = "1.7.0" dependencies = [ "anyhow", "arbitrary", + "cairo-lang-starknet-classes", "criterion", "katana-metrics", "katana-primitives", @@ -6000,6 +6001,7 @@ dependencies = [ "anyhow", "assert_matches", "blockifier 0.0.0 (git+https://github.com/dojoengine/sequencer?rev=5d737b9c9)", + "cairo-lang-starknet-classes", "cairo-native", "cairo-vm 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", "criterion", @@ -6344,7 +6346,9 @@ dependencies = [ "assert_matches", "blockifier 0.0.0 (git+https://github.com/dojoengine/sequencer?rev=5d737b9c9)", "cainome-cairo-serde", + "cairo-lang-sierra", "cairo-lang-starknet-classes", + "cairo-lang-utils", "cairo-vm 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", "criterion", "derive_more 0.99.20", @@ -6510,7 +6514,6 @@ dependencies = [ "cainome-cairo-serde", "cairo-lang-starknet-classes", "cairo-lang-utils", - "derive_more 0.99.20", "flate2", "katana-genesis", "katana-primitives", diff --git a/crates/executor/Cargo.toml b/crates/executor/Cargo.toml index 26f9a407e..789481b8f 100644 --- a/crates/executor/Cargo.toml +++ b/crates/executor/Cargo.toml @@ -19,6 +19,7 @@ thiserror.workspace = true tracing.workspace = true # cairo-native +cairo-lang-starknet-classes = { workspace = true, optional = true } cairo-native = { version = "0.4.1", optional = true } cairo-vm.workspace = true parking_lot.workspace = true @@ -49,7 +50,7 @@ pprof.workspace = true rayon.workspace = true [features] -native = [ "blockifier/cairo_native", "dep:cairo-native", "dep:rayon" ] +native = [ "blockifier/cairo_native", "dep:cairo-native", "dep:rayon", "dep:cairo-lang-starknet-classes" ] [[bench]] harness = false diff --git a/crates/executor/src/implementation/blockifier/cache.rs b/crates/executor/src/implementation/blockifier/cache.rs index 94ff648fc..b99a5351d 100644 --- a/crates/executor/src/implementation/blockifier/cache.rs +++ b/crates/executor/src/implementation/blockifier/cache.rs @@ -250,7 +250,12 @@ impl ClassCache { use cairo_native::OptLevel; #[cfg(feature = "native")] - let program = sierra.extract_sierra_program().unwrap(); + let program = cairo_lang_starknet_classes::contract_class::ContractClass::from( + sierra.clone(), // TODO: avoid cloning here + ) + .extract_sierra_program() + .unwrap(); + #[cfg(feature = "native")] let entry_points = sierra.entry_points_by_type.clone(); diff --git a/crates/gateway/gateway-types/src/lib.rs b/crates/gateway/gateway-types/src/lib.rs index d2d91043a..e1079f67e 100644 --- a/crates/gateway/gateway-types/src/lib.rs +++ b/crates/gateway/gateway-types/src/lib.rs @@ -27,7 +27,7 @@ use katana_primitives::class::{ClassHash, CompiledClassHash}; use katana_primitives::contract::{Nonce, StorageKey, StorageValue}; use katana_primitives::da::L1DataAvailabilityMode; use katana_primitives::{ContractAddress, Felt}; -pub use katana_rpc_types::class::SierraClass; +pub use katana_rpc_types::class::RpcSierraContractClass; use serde::{Deserialize, Serialize}; use starknet::core::types::ResourcePrice; diff --git a/crates/gateway/gateway-types/src/transaction.rs b/crates/gateway/gateway-types/src/transaction.rs index ae24a88ef..6ffbe9aa4 100644 --- a/crates/gateway/gateway-types/src/transaction.rs +++ b/crates/gateway/gateway-types/src/transaction.rs @@ -24,7 +24,6 @@ use katana_primitives::transaction::{ TxWithHash, }; use katana_primitives::{ContractAddress, Felt}; -use katana_rpc_types::SierraClassAbi; use serde::{Deserialize, Deserializer, Serialize}; /// API response for an INVOKE_FUNCTION transaction @@ -257,13 +256,15 @@ pub struct CompressedSierraClass { pub sierra_program: CompressedSierraProgram, pub contract_class_version: String, pub entry_points_by_type: ContractEntryPoints, - pub abi: SierraClassAbi, + pub abi: Option, } -impl TryFrom for CompressedSierraClass { +impl TryFrom for CompressedSierraClass { type Error = CompressedSierraProgramError; - fn try_from(value: katana_rpc_types::class::SierraClass) -> Result { + fn try_from( + value: katana_rpc_types::class::RpcSierraContractClass, + ) -> Result { let abi = value.abi; let entry_points_by_type = value.entry_points_by_type; let contract_class_version = value.contract_class_version; @@ -272,7 +273,7 @@ impl TryFrom for CompressedSierraClass { } } -impl TryFrom for katana_rpc_types::class::SierraClass { +impl TryFrom for katana_rpc_types::class::RpcSierraContractClass { type Error = CompressedSierraProgramError; fn try_from(value: CompressedSierraClass) -> Result { @@ -724,7 +725,7 @@ mod tests { use katana_primitives::fee::{ResourceBounds, ResourceBoundsMapping, Tip}; use katana_primitives::{address, Felt}; - use katana_rpc_types::SierraClass; + use katana_rpc_types::RpcSierraContractClass; use super::*; @@ -772,7 +773,7 @@ mod tests { #[test] fn test_conversion_from_rpc_query_declare_tx() { - let sierra_class = Arc::new(katana_rpc_types::class::SierraClass { + let sierra_class = Arc::new(katana_rpc_types::class::RpcSierraContractClass { sierra_program: vec![Felt::from(0x123), Felt::from(0x456)], contract_class_version: "0.1.0".to_string(), entry_points_by_type: Default::default(), @@ -820,7 +821,7 @@ mod tests { assert_eq!(gateway_tx.fee_data_availability_mode, rpc_tx.fee_data_availability_mode.into()); // convert the gateway contract class to rpc contract class and ensure they are equal - let converted_sierra_class: SierraClass = + let converted_sierra_class: RpcSierraContractClass = gateway_tx.contract_class.as_ref().clone().try_into().unwrap(); assert_eq!(converted_sierra_class, sierra_class.as_ref().clone()); } diff --git a/crates/primitives/Cargo.toml b/crates/primitives/Cargo.toml index c7fce089d..33cd162ca 100644 --- a/crates/primitives/Cargo.toml +++ b/crates/primitives/Cargo.toml @@ -7,12 +7,15 @@ version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +cairo-vm.workspace = true +cairo-lang-sierra.workspace = true +cairo-lang-utils.workspace = true +cairo-lang-starknet-classes.workspace = true + anyhow.workspace = true arbitrary = { workspace = true, optional = true } blockifier = { workspace = true, features = [ "testing" ] } # some Clone derives are gated behind 'testing' feature cainome-cairo-serde.workspace = true -cairo-lang-starknet-classes.workspace = true -cairo-vm.workspace = true derive_more.workspace = true heapless = { version = "0.8.0", features = [ "serde" ] } lazy_static.workspace = true diff --git a/crates/primitives/src/class.rs b/crates/primitives/src/class.rs index 1d8bcad1a..3bc2adcda 100644 --- a/crates/primitives/src/class.rs +++ b/crates/primitives/src/class.rs @@ -1,10 +1,10 @@ use std::str::FromStr; -use cairo_lang_starknet_classes::abi; use cairo_lang_starknet_classes::casm_contract_class::StarknetSierraCompilationError; use cairo_lang_starknet_classes::contract_class::{ version_id_from_serialized_sierra_program, ContractEntryPoint, ContractEntryPoints, }; +use cairo_lang_utils::bigint::BigUintAsHex; use serde_json_pythonic::to_string_pythonic; use starknet::macros::short_string; use starknet_api::contract_class::SierraVersion; @@ -18,8 +18,6 @@ pub type ClassHash = Felt; /// The hash of a compiled contract class. pub type CompiledClassHash = Felt; -/// The canonical contract class (Sierra) type. -pub type SierraContractClass = cairo_lang_starknet_classes::contract_class::ContractClass; /// The canonical legacy class (Cairo 0) type. pub type LegacyContractClass = starknet_api::deprecated_contract_class::ContractClass; @@ -29,6 +27,97 @@ pub type CasmContractClass = cairo_lang_starknet_classes::casm_contract_class::C /// ABI for Sierra-based classes. pub type ContractAbi = cairo_lang_starknet_classes::abi::Contract; +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize), serde(untagged))] +pub enum MaybeInvalidSierraContractAbi { + Valid(ContractAbi), + Invalid(String), +} + +impl std::fmt::Display for MaybeInvalidSierraContractAbi { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MaybeInvalidSierraContractAbi::Valid(abi) => { + let s = to_string_pythonic(abi).expect("failed to serialize abi"); + write!(f, "{}", s) + } + MaybeInvalidSierraContractAbi::Invalid(abi) => write!(f, "{}", abi), + } + } +} + +impl From for MaybeInvalidSierraContractAbi { + fn from(value: String) -> Self { + match serde_json::from_str::(&value) { + Ok(abi) => MaybeInvalidSierraContractAbi::Valid(abi), + Err(..) => MaybeInvalidSierraContractAbi::Invalid(value), + } + } +} + +impl From<&str> for MaybeInvalidSierraContractAbi { + fn from(value: &str) -> Self { + match serde_json::from_str::(value) { + Ok(abi) => MaybeInvalidSierraContractAbi::Valid(abi), + Err(..) => MaybeInvalidSierraContractAbi::Invalid(value.to_string()), + } + } +} + +/// Represents a contract in the Starknet network. +/// +/// The canonical contract class (Sierra) type. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] +pub struct SierraContractClass { + pub sierra_program: Vec, + pub sierra_program_debug_info: Option, + pub contract_class_version: String, + pub entry_points_by_type: ContractEntryPoints, + pub abi: Option, +} + +impl SierraContractClass { + /// Computes the hash of the Sierra contract class. + pub fn hash(&self) -> ClassHash { + let Self { sierra_program, abi, entry_points_by_type, .. } = self; + + let program: Vec = sierra_program.iter().map(|f| f.value.clone().into()).collect(); + let abi: String = abi.as_ref().map(|abi| abi.to_string()).unwrap_or_default(); + + compute_sierra_class_hash(&abi, entry_points_by_type, &program) + } +} + +impl From for cairo_lang_starknet_classes::contract_class::ContractClass { + fn from(value: SierraContractClass) -> Self { + let abi = value.abi.and_then(|abi| match abi { + MaybeInvalidSierraContractAbi::Invalid(..) => None, + MaybeInvalidSierraContractAbi::Valid(abi) => Some(abi), + }); + + cairo_lang_starknet_classes::contract_class::ContractClass { + abi, + sierra_program: value.sierra_program, + entry_points_by_type: value.entry_points_by_type, + contract_class_version: value.contract_class_version, + sierra_program_debug_info: value.sierra_program_debug_info, + } + } +} + +impl From for SierraContractClass { + fn from(value: cairo_lang_starknet_classes::contract_class::ContractClass) -> Self { + SierraContractClass { + abi: value.abi.map(MaybeInvalidSierraContractAbi::Valid), + sierra_program: value.sierra_program, + entry_points_by_type: value.entry_points_by_type, + contract_class_version: value.contract_class_version, + sierra_program_debug_info: value.sierra_program_debug_info, + } + } +} + #[derive(Debug, thiserror::Error)] pub enum ContractClassCompilationError { #[error(transparent)] @@ -51,23 +140,7 @@ impl ContractClass { /// Computes the hash of the class. pub fn class_hash(&self) -> Result { match self { - Self::Class(class) => { - // Technically we don't have to use the Pythonic JSON style here. Doing this just to - // align with the official `cairo-lang` CLI. - // - // TODO: add an `AbiFormatter` trait and let users choose which one to use. - let abi = class.abi.as_ref(); - let abi_str = to_string_pythonic(abi.unwrap_or(&abi::Contract::default())).unwrap(); - - let sierra_program = &class - .sierra_program - .iter() - .map(|f| f.value.clone().into()) - .collect::>(); - - Ok(compute_sierra_class_hash(&abi_str, &class.entry_points_by_type, sierra_program)) - } - + Self::Class(class) => Ok(class.hash()), Self::Legacy(class) => compute_legacy_class_hash(class), } } @@ -77,7 +150,7 @@ impl ContractClass { match self { Self::Legacy(class) => Ok(CompiledClass::Legacy(class)), Self::Class(class) => { - let casm = CasmContractClass::from_contract_class(class, true, usize::MAX)?; + let casm = CasmContractClass::from_contract_class(class.into(), true, usize::MAX)?; let casm = CompiledClass::Class(casm); Ok(casm) } diff --git a/crates/rpc/rpc-server/tests/starknet.rs b/crates/rpc/rpc-server/tests/starknet.rs index fa6e152dd..bd7759e52 100644 --- a/crates/rpc/rpc-server/tests/starknet.rs +++ b/crates/rpc/rpc-server/tests/starknet.rs @@ -166,9 +166,9 @@ async fn get_compiled_casm() { // Setup expected compiled class data to verify against use katana_primitives::class::{ContractClass, SierraContractClass}; - use katana_rpc_types::SierraClass as RpcSierraClass; + use katana_rpc_types::RpcSierraContractClass; - let rpc_class = RpcSierraClass::try_from(contract).unwrap(); + let rpc_class = RpcSierraContractClass::try_from(contract).unwrap(); let class = SierraContractClass::try_from(rpc_class).unwrap(); let expected_casm = ContractClass::Class(class).compile().unwrap(); diff --git a/crates/rpc/rpc-types/Cargo.toml b/crates/rpc/rpc-types/Cargo.toml index f570a53aa..5a71a4fac 100644 --- a/crates/rpc/rpc-types/Cargo.toml +++ b/crates/rpc/rpc-types/Cargo.toml @@ -12,7 +12,6 @@ katana-genesis.workspace = true katana-trie.workspace = true serde-utils.workspace = true -derive_more.workspace = true cainome.workspace = true cainome-cairo-serde.workspace = true cairo-lang-starknet-classes.workspace = true diff --git a/crates/rpc/rpc-types/src/broadcasted.rs b/crates/rpc/rpc-types/src/broadcasted.rs index 7d416aae6..c384fad43 100644 --- a/crates/rpc/rpc-types/src/broadcasted.rs +++ b/crates/rpc/rpc-types/src/broadcasted.rs @@ -15,7 +15,7 @@ use katana_primitives::utils::get_contract_address; use katana_primitives::{ContractAddress, Felt}; use serde::{de, Deserialize, Deserializer, Serialize}; -use crate::class::SierraClass; +use crate::class::RpcSierraContractClass; pub const QUERY_VERSION_OFFSET: Felt = Felt::from_raw([576460752142434320, 18446744073709551584, 17407, 18446744073700081665]); @@ -71,7 +71,7 @@ pub struct UntypedBroadcastedTx { pub compiled_class_hash: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub contract_class: Option>, + pub contract_class: Option>, // Invoke & Declare only field #[serde(default, skip_serializing_if = "Option::is_none")] @@ -476,7 +476,7 @@ pub struct BroadcastedDeclareTx { /// a transaction to be valid for execution, the nonce must be equal to the account's current /// nonce. pub nonce: Nonce, - pub contract_class: Arc, + pub contract_class: Arc, /// Data needed to allow the paymaster to pay for the transaction in native tokens. pub paymaster_data: Vec, /// The tip for the transaction. diff --git a/crates/rpc/rpc-types/src/class.rs b/crates/rpc/rpc-types/src/class.rs index 83132f61b..d4e69db86 100644 --- a/crates/rpc/rpc-types/src/class.rs +++ b/crates/rpc/rpc-types/src/class.rs @@ -13,7 +13,6 @@ use katana_primitives::{ Felt, {self}, }; use serde::{Deserialize, Serialize}; -use serde_json_pythonic::to_string_pythonic; use serde_utils::base64; use starknet::core::types::{CompressedLegacyContractClass, FlattenedSierraClass}; use starknet_api::contract_class::EntryPointType; @@ -31,8 +30,8 @@ pub type CasmClass = katana_primitives::class::CompiledClass; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum Class { - Sierra(SierraClass), - Legacy(LegacyClass), + Sierra(RpcSierraContractClass), + Legacy(RpcLegacyContractClass), } #[derive(Debug, thiserror::Error)] @@ -49,58 +48,32 @@ pub enum ConversionError { // -- SIERRA CLASS -/// The ABI is serialized as Pythonic JSON. -#[derive(Debug, Clone, PartialEq, Eq, Default, derive_more::Deref, derive_more::From)] -pub struct SierraClassAbi(katana_primitives::class::ContractAbi); - -impl std::fmt::Display for SierraClassAbi { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = to_string_pythonic(&self.0).map_err(|_| std::fmt::Error)?; - write!(f, "{s}") - } -} - -impl Serialize for SierraClassAbi { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for SierraClassAbi { - fn deserialize>(deserializer: D) -> Result { - let str = String::deserialize(deserializer)?; - serde_json::from_str::(&str) - .map(Self) - .map_err(|e| serde::de::Error::custom(format!("invalid abi format: {e}"))) - } -} - #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] -pub struct SierraClass { +pub struct RpcSierraContractClass { pub sierra_program: Vec, pub contract_class_version: String, pub entry_points_by_type: ContractEntryPoints, - pub abi: SierraClassAbi, + pub abi: Option, } ////////////////////////////////////////////////// // SierraClass implementations ////////////////////////////////////////////////// -impl SierraClass { +impl RpcSierraContractClass { /// Computes the hash of the Sierra class. pub fn hash(&self) -> ClassHash { compute_sierra_class_hash( - &self.abi.to_string(), + self.abi.as_deref().unwrap_or(""), &self.entry_points_by_type, &self.sierra_program, ) } } -impl From for SierraClass { +impl From for RpcSierraContractClass { fn from(value: SierraContractClass) -> Self { - let abi = value.abi.map(SierraClassAbi).unwrap_or_default(); + let abi = value.abi.map(|abi| abi.to_string()); let program = value.sierra_program.into_iter().map(|f| f.value.into()).collect::>(); Self { @@ -112,16 +85,18 @@ impl From for SierraClass { } } -impl From for SierraContractClass { - fn from(value: SierraClass) -> Self { +impl From for SierraContractClass { + fn from(value: RpcSierraContractClass) -> Self { let program = value .sierra_program .into_iter() .map(|f| BigUintAsHex { value: f.to_biguint() }) .collect::>(); + let abi = value.abi.map(|abi| abi.into()); + Self { - abi: value.abi.0.into(), + abi, sierra_program: program, sierra_program_debug_info: None, entry_points_by_type: value.entry_points_by_type, @@ -133,7 +108,7 @@ impl From for SierraContractClass { // -- LEGACY CLASS #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LegacyClass { +pub struct RpcLegacyContractClass { /// A base64 representation of the compressed program code #[serde(with = "base64")] pub program: Vec, @@ -148,7 +123,7 @@ pub struct LegacyClass { // LegacyClass implementations ////////////////////////////////////////////////// -impl TryFrom for LegacyClass { +impl TryFrom for RpcLegacyContractClass { type Error = ConversionError; fn try_from(value: LegacyContractClass) -> Result { @@ -157,10 +132,10 @@ impl TryFrom for LegacyClass { } } -impl TryFrom for LegacyContractClass { +impl TryFrom for LegacyContractClass { type Error = ConversionError; - fn try_from(value: LegacyClass) -> Result { + fn try_from(value: RpcLegacyContractClass) -> Result { let program = decompress_legacy_program(&value.program)?; Ok(Self { program, abi: value.abi, entry_points_by_type: value.entry_points_by_type }) } @@ -190,8 +165,10 @@ impl TryFrom for Class { fn try_from(value: ContractClass) -> Result { match value { - ContractClass::Class(class) => Ok(Self::Sierra(SierraClass::from(class))), - ContractClass::Legacy(class) => Ok(Self::Legacy(LegacyClass::try_from(class)?)), + ContractClass::Class(class) => Ok(Self::Sierra(RpcSierraContractClass::from(class))), + ContractClass::Legacy(class) => { + Ok(Self::Legacy(RpcLegacyContractClass::try_from(class)?)) + } } } } @@ -220,10 +197,10 @@ impl TryFrom for Class { fn try_from(value: starknet::core::types::ContractClass) -> Result { match value { starknet::core::types::ContractClass::Legacy(class) => { - Ok(Self::Legacy(LegacyClass::try_from(class)?)) + Ok(Self::Legacy(RpcLegacyContractClass::try_from(class)?)) } starknet::core::types::ContractClass::Sierra(class) => { - Ok(Self::Sierra(SierraClass::try_from(class)?)) + Ok(Self::Sierra(RpcSierraContractClass::try_from(class)?)) } } } @@ -242,7 +219,7 @@ impl TryFrom for starknet::core::types::ContractClass { } } -impl TryFrom for SierraClass { +impl TryFrom for RpcSierraContractClass { type Error = ConversionError; fn try_from(value: FlattenedSierraClass) -> Result { @@ -252,17 +229,17 @@ impl TryFrom for SierraClass { } } -impl TryFrom for FlattenedSierraClass { +impl TryFrom for FlattenedSierraClass { type Error = ConversionError; - fn try_from(value: SierraClass) -> Result { + fn try_from(value: RpcSierraContractClass) -> Result { let value = serde_json::to_value(value)?; let class = serde_json::from_value::(value)?; Ok(class) } } -impl TryFrom for LegacyClass { +impl TryFrom for RpcLegacyContractClass { type Error = ConversionError; fn try_from(value: CompressedLegacyContractClass) -> Result { @@ -272,10 +249,10 @@ impl TryFrom for LegacyClass { } } -impl TryFrom for CompressedLegacyContractClass { +impl TryFrom for CompressedLegacyContractClass { type Error = ConversionError; - fn try_from(value: LegacyClass) -> Result { + fn try_from(value: RpcLegacyContractClass) -> Result { let value = serde_json::to_value(value)?; let class = serde_json::from_value::(value)?; Ok(class) @@ -284,12 +261,16 @@ impl TryFrom for CompressedLegacyContractClass { #[cfg(test)] mod tests { - use katana_primitives::class::{ContractClass, LegacyContractClass, SierraContractClass}; + use katana_primitives::class::{ + ContractAbi, ContractClass, LegacyContractClass, MaybeInvalidSierraContractAbi, + SierraContractClass, + }; + use serde_json::json; use starknet::core::types::contract::legacy::LegacyContractClass as StarknetRsLegacyContractClass; use starknet::core::types::contract::SierraClass as StarknetRsSierraClass; - use super::LegacyClass; - use crate::class::SierraClass; + use super::RpcLegacyContractClass; + use crate::class::RpcSierraContractClass; #[test] fn rt() { @@ -297,7 +278,7 @@ mod tests { include_str!("../../../contracts/build/katana_account_Account.contract_class.json"); let class = serde_json::from_str::(json).unwrap(); - let rpc = SierraClass::try_from(class.clone()).unwrap(); + let rpc = RpcSierraContractClass::try_from(class.clone()).unwrap(); let rt = SierraContractClass::try_from(rpc).unwrap(); assert_eq!(class.abi, rt.abi); @@ -311,7 +292,7 @@ mod tests { let json = include_str!("../../../contracts/build/legacy/account.json"); let class = serde_json::from_str::(json).unwrap(); - let rpc = LegacyClass::try_from(class.clone()).unwrap(); + let rpc = RpcLegacyContractClass::try_from(class.clone()).unwrap(); let rt = LegacyContractClass::try_from(rpc).unwrap(); assert_eq!(class.abi, rt.abi); @@ -342,7 +323,7 @@ mod tests { // -- katana - let rpc = SierraClass::try_from(starknet_rpc).unwrap(); + let rpc = RpcSierraContractClass::try_from(starknet_rpc).unwrap(); let class = SierraContractClass::try_from(rpc).unwrap(); let hash = ContractClass::Class(class.clone()).class_hash().unwrap(); @@ -371,7 +352,7 @@ mod tests { // -- katana - let rpc = serde_json::from_str::(&json).unwrap(); + let rpc = serde_json::from_str::(&json).unwrap(); let class = LegacyContractClass::try_from(rpc).unwrap(); let hash = ContractClass::Legacy(class.clone()).class_hash().unwrap(); @@ -396,7 +377,7 @@ mod tests { include_str!("../../../contracts/build/katana_account_Account.contract_class.json"); let class = serde_json::from_str::(json).unwrap(); - let rpc_class = SierraClass::try_from(class.clone()).unwrap(); + let rpc_class = RpcSierraContractClass::try_from(class.clone()).unwrap(); let rpc_class_hash = rpc_class.hash(); let primitive = ContractClass::Class(SierraContractClass::try_from(rpc_class).unwrap()); @@ -404,4 +385,97 @@ mod tests { assert_eq!(rpc_class_hash, primitive_class_hash); } + + #[test] + fn rpc_sierra_class_with_valid_abi() { + let raw_abi_string = + "[{\"type\": \"constructor\", \"name\": \"constructor\", \"inputs\": []}, {\"type\": \ + \"function\", \"name\": \"Ekubo\", \"inputs\": [], \"outputs\": [], \ + \"state_mutability\": \"external\"}, {\"type\": \"event\", \"name\": \ + \"ekubo_spammer::EkuboSpammer::EkuboEvent\", \"kind\": \"struct\", \"members\": []}, \ + {\"type\": \"event\", \"name\": \"ekubo_spammer::EkuboSpammer::Event\", \"kind\": \ + \"enum\", \"variants\": [{\"name\": \"Ekubo\", \"type\": \ + \"ekubo_spammer::EkuboSpammer::EkuboEvent\", \"kind\": \"nested\"}]}]"; + + let json = json!({ + "sierra_program": ["0x1", "0x2"], + "contract_class_version": "0.1.0", + "entry_points_by_type": { + "EXTERNAL": [], + "L1_HANDLER": [], + "CONSTRUCTOR": [] + }, + "abi": &raw_abi_string + }); + + let rpc_class = serde_json::from_value::(json).unwrap(); + let rpc_class_hash = rpc_class.hash(); + + assert_eq!(rpc_class.abi, Some(raw_abi_string.to_string())); + + // convert to primitive and ensure the conversion preserves the exact ABI string + let primitive_class = SierraContractClass::from(rpc_class); + let primitive_class_hash = primitive_class.hash(); + + let exected_abi = serde_json::from_str::(raw_abi_string).unwrap(); + + assert_eq!(rpc_class_hash, primitive_class_hash); + assert_eq!(primitive_class.abi, Some(MaybeInvalidSierraContractAbi::Valid(exected_abi))); + } + + #[test] + fn rpc_sierra_class_with_arbitrary_abi_str() { + let raw_abi_string = "hello world!"; + + let json = json!({ + "sierra_program": ["0x1", "0x2"], + "contract_class_version": "0.1.0", + "entry_points_by_type": { + "EXTERNAL": [], + "L1_HANDLER": [], + "CONSTRUCTOR": [] + }, + "abi": &raw_abi_string + }); + + let rpc_class = serde_json::from_value::(json).unwrap(); + let rpc_class_hash = rpc_class.hash(); + + assert_eq!(rpc_class.abi, Some(raw_abi_string.to_string())); + + // convert to primitive and ensure the conversion preserves the exact ABI string + let primitive_class = SierraContractClass::from(rpc_class); + let primitive_class_hash = primitive_class.hash(); + + assert_eq!(rpc_class_hash, primitive_class_hash); + assert_eq!( + primitive_class.abi, + Some(MaybeInvalidSierraContractAbi::Invalid(raw_abi_string.to_string())) + ); + } + + #[test] + fn rpc_sierra_class_with_empty_abi() { + let json = json!({ + "sierra_program": ["0x1", "0x2"], + "contract_class_version": "0.1.0", + "entry_points_by_type": { + "EXTERNAL": [], + "L1_HANDLER": [], + "CONSTRUCTOR": [] + }, + }); + + let rpc_class = serde_json::from_value::(json).unwrap(); + let rpc_class_hash = rpc_class.hash(); + + assert!(rpc_class.abi.is_none()); + + // convert to primitive and ensure the conversion preserves the exact ABI string + let primitive_class = SierraContractClass::from(rpc_class); + let primitive_class_hash = primitive_class.hash(); + + assert_eq!(rpc_class_hash, primitive_class_hash); + assert_eq!(primitive_class.abi, None); + } } diff --git a/crates/storage/db/Cargo.toml b/crates/storage/db/Cargo.toml index 66f5713b4..315bcb04e 100644 --- a/crates/storage/db/Cargo.toml +++ b/crates/storage/db/Cargo.toml @@ -12,6 +12,7 @@ katana-trie.workspace = true katana-metrics.workspace = true anyhow.workspace = true +cairo-lang-starknet-classes.workspace = true arbitrary = { workspace = true, optional = true } metrics.workspace = true page_size = "0.6.0" diff --git a/crates/storage/db/src/codecs/mod.rs b/crates/storage/db/src/codecs/mod.rs index 20ab27ee3..bd3f6f586 100644 --- a/crates/storage/db/src/codecs/mod.rs +++ b/crates/storage/db/src/codecs/mod.rs @@ -2,7 +2,6 @@ pub mod postcard; use katana_primitives::block::FinalityStatus; -use katana_primitives::class::ContractClass; use katana_primitives::contract::ContractAddress; use katana_primitives::Felt; @@ -85,19 +84,6 @@ impl Decode for String { } } -impl Compress for ContractClass { - type Compressed = Vec; - fn compress(self) -> Result { - serde_json::to_vec(&self).map_err(|e| CodecError::Compress(e.to_string())) - } -} - -impl Decompress for ContractClass { - fn decompress>(bytes: B) -> Result { - serde_json::from_slice(bytes.as_ref()).map_err(|e| CodecError::Decode(e.to_string())) - } -} - impl Compress for FinalityStatus { type Compressed = [u8; 1]; fn compress(self) -> Result { diff --git a/crates/storage/db/src/models/mod.rs b/crates/storage/db/src/models/mod.rs index 49bf2a5af..47388db34 100644 --- a/crates/storage/db/src/models/mod.rs +++ b/crates/storage/db/src/models/mod.rs @@ -9,4 +9,5 @@ pub mod trie; pub mod versioned; pub use versioned::block::VersionedHeader; +pub use versioned::class::VersionedContractClass; pub use versioned::transaction::VersionedTx; diff --git a/crates/storage/db/src/models/versioned/class/mod.rs b/crates/storage/db/src/models/versioned/class/mod.rs new file mode 100644 index 000000000..39025bb0f --- /dev/null +++ b/crates/storage/db/src/models/versioned/class/mod.rs @@ -0,0 +1,66 @@ +use serde::{Deserialize, Serialize}; + +mod v7; + +use crate::codecs::{Compress, Decompress}; +use crate::error::CodecError; + +pub type CurrentContractClass = katana_primitives::class::ContractClass; +pub type V8ContractClass = CurrentContractClass; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum VersionedContractClass { + V7(v7::ContractClass), + V8(V8ContractClass), +} + +impl From for VersionedContractClass { + fn from(class: CurrentContractClass) -> Self { + Self::V8(class) + } +} + +impl From for CurrentContractClass { + fn from(versioned: VersionedContractClass) -> Self { + match versioned { + VersionedContractClass::V8(class) => class, + VersionedContractClass::V7(class) => class.into(), + } + } +} + +impl Default for VersionedContractClass { + fn default() -> Self { + Self::V8(CurrentContractClass::Legacy(Default::default())) + } +} + +impl Compress for VersionedContractClass { + type Compressed = Vec; + fn compress(self) -> Result { + serde_json::to_vec(&self).map_err(|e| CodecError::Compress(e.to_string())) + } +} + +impl Decompress for VersionedContractClass { + fn decompress>(bytes: B) -> Result { + let bytes = bytes.as_ref(); + + if let Ok(class) = serde_json::from_slice::(bytes) { + return Ok(class); + } + + // Try deserializing as V8 (current) first, then fall back to V7 + if let Ok(class) = serde_json::from_slice::(bytes) { + return Ok(VersionedContractClass::V8(class)); + } + + if let Ok(class) = serde_json::from_slice::(bytes) { + return Ok(VersionedContractClass::V7(class)); + } + + Err(CodecError::Decompress( + "failed to deserialize contract class: unknown format".to_string(), + )) + } +} diff --git a/crates/storage/db/src/models/versioned/class/v7.rs b/crates/storage/db/src/models/versioned/class/v7.rs new file mode 100644 index 000000000..b31a9bb66 --- /dev/null +++ b/crates/storage/db/src/models/versioned/class/v7.rs @@ -0,0 +1,17 @@ +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone, PartialEq, Eq, ::serde::Serialize, ::serde::Deserialize)] +pub enum ContractClass { + Class(cairo_lang_starknet_classes::contract_class::ContractClass), + Legacy(katana_primitives::class::LegacyContractClass), +} + +impl From for katana_primitives::class::ContractClass { + fn from(contract_class: ContractClass) -> Self { + match contract_class { + ContractClass::Legacy(class) => katana_primitives::class::ContractClass::Legacy(class), + ContractClass::Class(class) => { + katana_primitives::class::ContractClass::Class(class.into()) + } + } + } +} diff --git a/crates/storage/db/src/models/versioned/mod.rs b/crates/storage/db/src/models/versioned/mod.rs index 54493d0c5..7bd868da3 100644 --- a/crates/storage/db/src/models/versioned/mod.rs +++ b/crates/storage/db/src/models/versioned/mod.rs @@ -1,2 +1,3 @@ pub mod block; +pub mod class; pub mod transaction; diff --git a/crates/storage/db/src/tables.rs b/crates/storage/db/src/tables.rs index 18c0373d8..5a1e08957 100644 --- a/crates/storage/db/src/tables.rs +++ b/crates/storage/db/src/tables.rs @@ -1,5 +1,5 @@ use katana_primitives::block::{BlockHash, BlockNumber, FinalityStatus}; -use katana_primitives::class::{ClassHash, CompiledClassHash, ContractClass}; +use katana_primitives::class::{ClassHash, CompiledClassHash}; use katana_primitives::contract::{ContractAddress, GenericContractInfo, StorageKey}; use katana_primitives::execution::TypedTransactionExecutionInfo; use katana_primitives::receipt::Receipt; @@ -12,7 +12,7 @@ use crate::models::list::BlockList; use crate::models::stage::{StageCheckpoint, StageId}; use crate::models::storage::{ContractStorageEntry, ContractStorageKey, StorageEntry}; use crate::models::trie::{TrieDatabaseKey, TrieDatabaseValue, TrieHistoryEntry}; -use crate::models::{VersionedHeader, VersionedTx}; +use crate::models::{VersionedContractClass, VersionedHeader, VersionedTx}; pub trait Key: Encode + Decode + Clone + std::fmt::Debug {} pub trait Value: Compress + Decompress + std::fmt::Debug {} @@ -219,7 +219,7 @@ tables! { /// Store compiled classes CompiledClassHashes: (ClassHash) => CompiledClassHash, /// Store contract classes according to its class hash - Classes: (ClassHash) => ContractClass, + Classes: (ClassHash) => VersionedContractClass, /// Store contract information according to its contract address ContractInfo: (ContractAddress) => GenericContractInfo, /// Store contract storage diff --git a/crates/storage/db/src/version.rs b/crates/storage/db/src/version.rs index 4cb7d6fcd..2339c968f 100644 --- a/crates/storage/db/src/version.rs +++ b/crates/storage/db/src/version.rs @@ -6,7 +6,7 @@ use std::mem; use std::path::{Path, PathBuf}; /// Current version of the database. -pub const CURRENT_DB_VERSION: Version = Version::new(7); +pub const CURRENT_DB_VERSION: Version = Version::new(8); /// Name of the version file. const DB_VERSION_FILE_NAME: &str = "db.version"; @@ -92,6 +92,6 @@ mod tests { #[test] fn test_current_version() { use super::CURRENT_DB_VERSION; - assert_eq!(CURRENT_DB_VERSION.0, 7, "Invalid current database version") + assert_eq!(CURRENT_DB_VERSION.0, 8, "Invalid current database version") } } diff --git a/crates/storage/provider/provider/src/providers/db/mod.rs b/crates/storage/provider/provider/src/providers/db/mod.rs index ff822dcd5..6d61b876c 100644 --- a/crates/storage/provider/provider/src/providers/db/mod.rs +++ b/crates/storage/provider/provider/src/providers/db/mod.rs @@ -744,7 +744,7 @@ impl BlockWriter for DbProvider { // insert all class artifacts for (class_hash, class) in states.classes { - db_tx.put::(class_hash, class)?; + db_tx.put::(class_hash, class.into())?; } // insert compiled class hashes and declarations for declared classes diff --git a/crates/storage/provider/provider/src/providers/db/state.rs b/crates/storage/provider/provider/src/providers/db/state.rs index 4fb2237fe..d5b3b2bcf 100644 --- a/crates/storage/provider/provider/src/providers/db/state.rs +++ b/crates/storage/provider/provider/src/providers/db/state.rs @@ -74,7 +74,7 @@ impl StateWriter for DbProvider { impl ContractClassWriter for DbProvider { fn set_class(&self, hash: ClassHash, class: ContractClass) -> ProviderResult<()> { self.0.update(move |db_tx| -> ProviderResult<()> { - db_tx.put::(hash, class)?; + db_tx.put::(hash, class.into())?; Ok(()) })? } @@ -106,7 +106,7 @@ where Tx: DbTx + Send + Sync, { fn class(&self, hash: ClassHash) -> ProviderResult> { - Ok(self.0.get::(hash)?) + Ok(self.0.get::(hash)?.map(|class| class.into())) } fn compiled_class_hash_of_class_hash( @@ -236,7 +236,7 @@ where { fn class(&self, hash: ClassHash) -> ProviderResult> { if self.is_class_declared_before_block(hash)? { - Ok(self.tx.get::(hash)?) + Ok(self.tx.get::(hash)?.map(Into::into)) } else { Ok(None) } diff --git a/crates/storage/provider/provider/src/providers/fork/state.rs b/crates/storage/provider/provider/src/providers/fork/state.rs index a8db588ea..3829eda57 100644 --- a/crates/storage/provider/provider/src/providers/fork/state.rs +++ b/crates/storage/provider/provider/src/providers/fork/state.rs @@ -76,7 +76,7 @@ where if let Some(class) = self.provider.class(hash)? { Ok(Some(class)) } else if let Some(class) = self.backend.get_class_at(hash)? { - self.db.db().update(|tx| tx.put::(hash, class.clone()))??; + self.db.db().update(|tx| tx.put::(hash, class.clone().into()))??; Ok(Some(class)) } else { Ok(None) @@ -225,7 +225,7 @@ impl ContractClassProvider for HistoricalStateProvider { if let res @ Some(..) = self.provider.class(hash)? { Ok(res) } else if let Some(class) = self.backend.get_class_at(hash)? { - self.db.db().tx_mut()?.put::(hash, class.clone())?; + self.db.db().tx_mut()?.put::(hash, class.clone().into())?; Ok(Some(class)) } else { Ok(None)