diff --git a/Cargo.lock b/Cargo.lock index 7ef7277c4d..4ea6761cef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,10 +362,12 @@ version = "4.5.1" dependencies = [ "arbitrary", "chrono", + "const_panic", "cool_asserts", "educe", "either", "itertools 0.14.0", + "konst", "lalrpop", "lalrpop-util", "linked-hash-map", @@ -573,6 +575,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "const_panic" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652" +dependencies = [ + "typewit", +] + [[package]] name = "cool_asserts" version = "2.0.3" @@ -1302,6 +1313,23 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "konst" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64896bdfd7906cfb0b57bc04f08bde408bcd6aaf71ff438ee471061cd16f2e86" +dependencies = [ + "const_panic", + "konst_proc_macros", + "typewit", +] + +[[package]] +name = "konst_proc_macros" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf92d396aa2df203577ebef8deaf1efc24d446366ca86be83ec8ac794b157d6" + [[package]] name = "lalrpop" version = "0.22.2" @@ -2957,6 +2985,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "typewit" +version = "1.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" + [[package]] name = "ucd-trie" version = "0.1.7" diff --git a/cedar-policy-core/Cargo.toml b/cedar-policy-core/Cargo.toml index 138b57fe32..c40e76b83f 100644 --- a/cedar-policy-core/Cargo.toml +++ b/cedar-policy-core/Cargo.toml @@ -32,6 +32,8 @@ unicode-security = "0.1.0" regex = { version = "1.11", features = ["unicode"]} linked-hash-map = { version = "0.5.6", features = ["serde_impl"] } linked_hash_set = "0.1.5" +konst = "0.4" +const_panic = "0.2" # wasm dependencies serde-wasm-bindgen = { version = "0.6", optional = true } diff --git a/cedar-policy-core/src/ast.rs b/cedar-policy-core/src/ast.rs index e8367a647d..eb54e2d46b 100644 --- a/cedar-policy-core/src/ast.rs +++ b/cedar-policy-core/src/ast.rs @@ -34,6 +34,8 @@ mod name; pub use name::*; mod ops; pub use ops::*; +mod path; +pub use path::*; mod pattern; pub use pattern::*; mod partial_value; diff --git a/cedar-policy-core/src/ast/entity.rs b/cedar-policy-core/src/ast/entity.rs index 7ad28a49f8..a977eb2d8f 100644 --- a/cedar-policy-core/src/ast/entity.rs +++ b/cedar-policy-core/src/ast/entity.rs @@ -33,12 +33,10 @@ use std::sync::Arc; use thiserror::Error; #[cfg(feature = "tolerant-ast")] -static ERROR_NAME: std::sync::LazyLock = - std::sync::LazyLock::new(|| Name(InternalName::from(Id::new_unchecked("EntityTypeError")))); +static ERROR_NAME: Name = Name(InternalName::new_from_path(Id::new_unchecked_from_static("EntityTypeError"), Path::empty(), None)); #[cfg(feature = "tolerant-ast")] -static ERROR_EID_SMOL_STR: std::sync::LazyLock = - std::sync::LazyLock::new(|| SmolStr::from("Eid::ErrorEid")); +static ERROR_EID_SMOL_STR: SmolStr = SmolStr::new_static("Eid::ErrorEid"); #[cfg(feature = "tolerant-ast")] static EID_ERROR_STR: &str = "Eid::Error"; @@ -50,7 +48,7 @@ static ENTITY_TYPE_ERROR_STR: &str = "EntityType::Error"; static ENTITY_UID_ERROR_STR: &str = "EntityUID::Error"; /// The entity type that Actions must have -pub static ACTION_ENTITY_TYPE: &str = "Action"; +pub static ACTION_ENTITY_TYPE: Id = Id::new_unchecked_from_static("Action"); #[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -94,7 +92,7 @@ impl EntityType { pub fn is_action(&self) -> bool { match self { EntityType::EntityType(name) => { - name.as_ref().basename() == &Id::new_unchecked(ACTION_ENTITY_TYPE) + name.as_ref().basename() == &ACTION_ENTITY_TYPE } #[cfg(feature = "tolerant-ast")] EntityType::ErrorEntityType => false, @@ -145,11 +143,16 @@ impl EntityType { pub fn from_normalized_str(src: &str) -> Result { Name::from_normalized_str(src).map(Into::into) } + + /// Convert a [`Name`] to an [`EntityType`] + pub const fn from_name(name: Name) -> Self { + Self::EntityType(name) + } } impl From for EntityType { fn from(n: Name) -> Self { - Self::EntityType(n) + Self::from_name(n) } } @@ -905,7 +908,7 @@ mod test { #[test] fn action_type_is_valid_id() { - assert!(Id::from_normalized_str(ACTION_ENTITY_TYPE).is_ok()); + assert!(Id::from_normalized_str(ACTION_ENTITY_TYPE.as_ref()).is_ok()); } #[cfg(feature = "tolerant-ast")] diff --git a/cedar-policy-core/src/ast/id.rs b/cedar-policy-core/src/ast/id.rs index 14dab56330..60982722bd 100644 --- a/cedar-policy-core/src/ast/id.rs +++ b/cedar-policy-core/src/ast/id.rs @@ -50,6 +50,12 @@ impl Id { Id(s.into()) } + /// Similar to `new_unchecked`, but for static strings which can be `const` + /// constructed + pub(crate) const fn new_unchecked_from_static(s: &'static str) -> Id { + Id(SmolStr::new_static(s)) + } + /// Get the underlying string pub fn into_smolstr(self) -> SmolStr { self.0 @@ -61,6 +67,11 @@ impl Id { pub fn is_reserved(&self) -> bool { self.as_ref() == RESERVED_ID } + + /// Check if the `Id` is static + pub const fn is_static(&self) -> bool { + !self.0.is_heap_allocated() + } } impl AsRef for Id { diff --git a/cedar-policy-core/src/ast/name.rs b/cedar-policy-core/src/ast/name.rs index 428f37e842..78249432aa 100644 --- a/cedar-policy-core/src/ast/name.rs +++ b/cedar-policy-core/src/ast/name.rs @@ -25,9 +25,9 @@ use smol_str::ToSmolStr; use std::collections::HashSet; use std::fmt::Display; use std::str::FromStr; -use std::sync::Arc; use thiserror::Error; +use crate::ast::Path; use crate::parser::err::{ParseError, ParseErrors, ToASTError, ToASTErrorKind}; use crate::parser::Loc; use crate::FromNormalizedStr; @@ -44,7 +44,7 @@ pub struct InternalName { /// Basename pub(crate) id: Id, /// Namespaces - pub(crate) path: Arc>, + pub(crate) path: Path, /// Location of the name in source #[educe(PartialEq(ignore))] #[educe(Hash(ignore))] @@ -76,26 +76,27 @@ impl TryFrom for Id { impl InternalName { /// A full constructor for [`InternalName`] pub fn new(basename: Id, path: impl IntoIterator, loc: Option) -> Self { + Self::new_from_path(basename, path.into_iter().collect(), loc) + } + + /// A full constructor for [`InternalName`] from a [`Path`] + pub const fn new_from_path(basename: Id, path: Path, loc: Option) -> Self { Self { id: basename, - path: Arc::new(path.into_iter().collect()), + path, loc, } } /// Create an [`InternalName`] with no path (no namespaces). pub fn unqualified_name(id: Id, loc: Option) -> Self { - Self { - id, - path: Arc::new(vec![]), - loc, - } + Self::new_from_path(id, Path::empty(), loc) } /// Get the [`InternalName`] representing the reserved `__cedar` namespace pub fn __cedar() -> Self { - // using `Id::new_unchecked()` for performance reasons -- this function is called many times by validator code - Self::unqualified_name(Id::new_unchecked("__cedar"), None) + // using `Id::new_unchecked_from_static()` for performance reasons -- this function is called many times by validator code + Self::unqualified_name(Id::new_unchecked_from_static("__cedar"), None) } /// Create an [`InternalName`] with no path (no namespaces). @@ -103,7 +104,7 @@ impl InternalName { pub fn parse_unqualified_name(s: &str) -> Result { Ok(Self { id: s.parse()?, - path: Arc::new(vec![]), + path: Path::empty(), loc: None, }) } @@ -115,7 +116,7 @@ impl InternalName { namespace: InternalName, loc: Option, ) -> InternalName { - let mut path = Arc::unwrap_or_clone(namespace.path); + let mut path = namespace.path.to_vec(); path.push(namespace.id); InternalName::new(basename, path, loc) } @@ -200,6 +201,11 @@ impl InternalName { .chain(std::iter::once(&self.id)) .any(|id| id.is_reserved()) } + + /// Check if the [`InternalName`] is static + pub const fn is_static(&self) -> bool { + self.id.is_static() && self.path.is_static() + } } impl std::fmt::Display for InternalName { @@ -244,11 +250,10 @@ impl<'a> arbitrary::Arbitrary<'a> for InternalName { let path_size = u.int_in_range(0..=8)?; Ok(Self { id: u.arbitrary()?, - path: Arc::new( - (0..path_size) - .map(|_| u.arbitrary()) - .collect::, _>>()?, - ), + path: (0..path_size) + .map(|_| u.arbitrary()) + .collect::, _>>()? + .into(), loc: None, }) } @@ -416,6 +421,173 @@ impl FromStr for Name { } } +#[derive(Debug)] +/// The error type for [`Name`] validation +pub enum NameValidationError { + /// A reserved keyword was used + ReservedKeyword(&'static str), + /// An empty name was used + Empty, + /// A part of the name started with a non-alphanumeric character + PartStart(char), + /// A part of the name contains a non-alphanumeric character + PartContains(char), +} + +impl std::fmt::Display for NameValidationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NameValidationError::ReservedKeyword(s) => write!(f, "part is a reserved keyword: {s}"), + NameValidationError::Empty => write!(f, "part is empty"), + NameValidationError::PartStart(c) => write!(f, "part starts with invalid char: {c}"), + NameValidationError::PartContains(c) => write!(f, "part contains invalid char: {c}"), + } + } +} + +impl std::error::Error for NameValidationError {} + +impl NameValidationError { + /// Panics with the error message + pub const fn into_panic(&self, msg: &str) -> ! { + match self { + NameValidationError::ReservedKeyword(s) => const_panic::concat_panic!(msg, ": part is a reserved keyword: ", s), + NameValidationError::Empty => const_panic::concat_panic!(msg, ": part is empty"), + NameValidationError::PartStart(c) => const_panic::concat_panic!(msg, ": part starts with invalid char: ", c), + NameValidationError::PartContains(c) => const_panic::concat_panic!(msg, ": part contains invalid char: ", c), + } + } +} + +const fn validate_part(part: &str) -> Option { + let invalid_parts = &["true", "false", "if", "then", "else", "in", "is", "like", "has", "__cedar"]; + konst::for_range! {idx in 0..invalid_parts.len() => + if konst::string::eq_str(invalid_parts[idx], part) { + return Some(NameValidationError::ReservedKeyword(invalid_parts[idx])); + } + } + + let mut chars = konst::string::chars(part); + + let Some(c) = chars.next() else { + return Some(NameValidationError::Empty); + }; + + if c != '_' && !c.is_ascii_alphabetic() { + return Some(NameValidationError::PartStart(c)); + } + + konst::iter::for_each!{c in chars => + if c != '_' && !c.is_ascii_alphanumeric() { + return Some(NameValidationError::PartContains(c)); + } + } + + None +} + +const fn validate(parts: &[&str]) -> Option { + if parts.is_empty() { + return Some(NameValidationError::Empty); + } + + konst::iter::for_each! {part in parts => + if let Some(result) = validate_part(part) { + return Some(result); + } + } + None +} + +/// A name that has been constructed at compile time +#[derive(Debug)] +pub enum CompileTimeName { + /// A valid name + Name(Name), + /// An invalid name + ValidationError(NameValidationError), +} + +impl CompileTimeName { + const EMPTY_NAME: Name = Name(InternalName::new_from_path(Id::new_unchecked_from_static(""), Path::empty(), None)); + + /// Unwrap the name + pub const fn unwrap(mut self) -> Name { + let r = match &mut self { + CompileTimeName::Name(name) => std::mem::replace(name, Self::EMPTY_NAME), + CompileTimeName::ValidationError(r) => r.into_panic("unwrap on invalid name"), + }; + + self.forget(); + r + } + + /// Unwrap the error + pub const fn unwrap_err(mut self) -> NameValidationError { + let r = match &mut self { + CompileTimeName::Name(_) => panic!("unwrap_err on valid name"), + CompileTimeName::ValidationError(r) => std::mem::replace(r, NameValidationError::Empty), + }; + + self.forget(); + r + } + + /// Expect the name + pub const fn expect(mut self, err: &str) -> Name { + let r = match &mut self { + CompileTimeName::Name(name) => std::mem::replace(name, Self::EMPTY_NAME), + CompileTimeName::ValidationError(r) => r.into_panic(err), + }; + + self.forget(); + r + } + + /// Expect the name + pub const fn expect_err(mut self, err: &str) -> NameValidationError { + let r = match &mut self { + CompileTimeName::Name(_) => const_panic::concat_panic!(err), + CompileTimeName::ValidationError(r) => std::mem::replace(r, NameValidationError::Empty), + }; + + self.forget(); + r + } + + const fn forget(self) { + if let CompileTimeName::Name(name) = &self { + assert!(name.0.is_static(), "name should be static when forgetting"); + } + std::mem::forget(self); + } +} + +/// This is how we can create a `Name` at compile time +#[macro_export] +macro_rules! make_name { + ($input:expr) => { + const { + const EXPR: &str = $input; + const PARTS_STR: &[&str] = &konst::iter::collect_const!(&str => + konst::string::split(EXPR, "::"), + ); + const PARTS_ID: &[Id] = &konst::iter::collect_const!(Id => + konst::slice::iter(konst::slice::slice_up_to(PARTS_STR, PARTS_STR.len() - 1)), + map(|part| { + Id::new_unchecked_from_static(part) + }), + ); + + if let Some(err) = validate(PARTS_STR) { + CompileTimeName::ValidationError(err) + } else { + CompileTimeName::Name(Name(InternalName::new_from_path(Id::new_unchecked_from_static(PARTS_STR[PARTS_STR.len() - 1]), Path::new_from_static(&PARTS_ID), None))) + } + } + } +} + // PANIC SAFETY: this is a valid Regex pattern #[allow(clippy::unwrap_used)] static VALID_NAME_REGEX: std::sync::LazyLock = std::sync::LazyLock::new(|| { @@ -638,6 +810,17 @@ impl<'a> arbitrary::Arbitrary<'a> for Name { mod test { use super::*; + #[test] + fn compile_time_name() { + const _: Name = make_name!("foo").expect("should be OK"); + const _: Name = make_name!("foo::bar").expect("should be OK"); + const _: NameValidationError = make_name!(r#"foo::"bar""#).expect_err("shouldn't be OK"); + const _: NameValidationError = make_name!(" foo").expect_err("shouldn't be OK"); + const _: NameValidationError = make_name!("foo ").expect_err("shouldn't be OK"); + const _: NameValidationError = make_name!("foo\n").expect_err("shouldn't be OK"); + const _: NameValidationError = make_name!("foo//comment").expect_err("shouldn't be OK"); + } + #[test] fn normalized_name() { InternalName::from_normalized_str("foo").expect("should be OK"); diff --git a/cedar-policy-core/src/ast/path.rs b/cedar-policy-core/src/ast/path.rs new file mode 100644 index 0000000000..0772df9348 --- /dev/null +++ b/cedar-policy-core/src/ast/path.rs @@ -0,0 +1,156 @@ +use std::sync::Arc; + +use crate::ast::Id; + +/// A path containing a list of identifiers +#[derive(Clone)] +pub enum Path { + /// An owned list of identifiers + Arc(Arc<[Id]>), + /// A static list of identifiers + Static(&'static [Id]), +} + +impl From> for Path { + fn from(value: Vec) -> Self { + Self::Arc(value.into()) + } +} + +impl From> for Path { + fn from(value: Arc<[Id]>) -> Self { + Self::Arc(value) + } +} + +impl From<&'static [Id]> for Path { + fn from(value: &'static [Id]) -> Self { + Self::Static(value) + } +} + +impl AsRef<[Id]> for Path { + fn as_ref(&self) -> &[Id] { + match self { + Self::Arc(value) => value.as_ref(), + Self::Static(value) => value, + } + } +} + +impl<'a> IntoIterator for &'a Path { + type Item = &'a Id; + type IntoIter = std::slice::Iter<'a, Id>; + fn into_iter(self) -> Self::IntoIter { + self.as_ref().iter() + } +} + +impl FromIterator for Path { + fn from_iter>(iter: T) -> Self { + Self::Arc(iter.into_iter().collect()) + } +} + +impl Path { + /// Create a new [`Path`] from an iterator + pub fn new(iter: impl IntoIterator) -> Self { + Self::from_iter(iter) + } + + /// Create a new [`Path`] from a static slice + pub const fn new_from_static(slice: &'static [Id]) -> Self { + Self::Static(slice) + } + + /// Create a new [`Path`] from an arc + pub const fn new_from_arc(ptr: Arc<[Id]>) -> Self { + Self::Arc(ptr) + } + + /// Create a new [`Path`] with no elements + pub const fn empty() -> Self { + Self::Static(&[]) + } + + /// Convert a [`Path`] to a [`Vec`] + pub fn to_vec(&self) -> Vec { + self.as_ref().to_vec() + } + + /// Borrowed iteration of the [`Path`]'s elements + pub fn iter(&self) -> impl Iterator { + self.into_iter() + } + + /// Check if the [`Path`] is empty + pub fn is_empty(&self) -> bool { + self.as_ref().is_empty() + } + + /// Convert a [`Path`] to a slice + pub fn as_slice(&self) -> &[Id] { + self.as_ref() + } + + /// Check if the [`Path`] is static + pub const fn is_static(&self) -> bool { + matches!(self, Self::Static(_)) + } +} + +impl std::fmt::Debug for Path { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_slice().fmt(f) + } +} + +impl PartialEq<[Id]> for Path { + fn eq(&self, other: &[Id]) -> bool { + self.as_slice() == other + } +} + +impl PartialEq for [Id] { + fn eq(&self, other: &Path) -> bool { + other == self + } +} + +impl PartialEq for Path { + fn eq(&self, other: &Path) -> bool { + self.as_slice() == other.as_slice() + } +} + +impl std::hash::Hash for Path { + fn hash(&self, state: &mut H) { + self.as_slice().hash(state) + } +} + +impl Eq for Path {} + +impl PartialOrd<[Id]> for Path { + fn partial_cmp(&self, other: &[Id]) -> Option { + self.as_slice().partial_cmp(other) + } +} + +impl PartialOrd for [Id] { + fn partial_cmp(&self, other: &Path) -> Option { + other.partial_cmp(self) + } +} + +impl PartialOrd for Path { + fn partial_cmp(&self, other: &Path) -> Option { + self.as_slice().partial_cmp(other.as_slice()) + } +} + +impl Ord for Path { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.as_slice().cmp(other.as_slice()) + } +} diff --git a/cedar-policy-core/src/parser/cst_to_ast.rs b/cedar-policy-core/src/parser/cst_to_ast.rs index 5376090ed5..7664001af7 100644 --- a/cedar-policy-core/src/parser/cst_to_ast.rs +++ b/cedar-policy-core/src/parser/cst_to_ast.rs @@ -2338,7 +2338,7 @@ fn construct_string_from_var(v: ast::Var) -> SmolStr { fn construct_name(path: Vec, id: ast::Id, loc: Option) -> ast::InternalName { ast::InternalName { id, - path: Arc::new(path), + path: path.into(), loc, } } diff --git a/cedar-policy-core/src/validator/diagnostics/validation_errors.rs b/cedar-policy-core/src/validator/diagnostics/validation_errors.rs index 50874203a0..35defcdd06 100644 --- a/cedar-policy-core/src/validator/diagnostics/validation_errors.rs +++ b/cedar-policy-core/src/validator/diagnostics/validation_errors.rs @@ -718,7 +718,7 @@ mod test_attr_access { ) { let env = RequestEnv::DeclaredAction { principal: &"Principal".parse().unwrap(), - action: &EntityUID::with_eid_and_type(crate::ast::ACTION_ENTITY_TYPE, "action") + action: &EntityUID::with_eid_and_type(crate::ast::ACTION_ENTITY_TYPE.as_ref(), "action") .unwrap(), resource: &"Resource".parse().unwrap(), context: &Type::record_with_attributes(None, OpenTag::ClosedAttributes), diff --git a/cedar-policy-core/src/validator/schema/namespace_def.rs b/cedar-policy-core/src/validator/schema/namespace_def.rs index 5b2328a058..82b8a0b24b 100644 --- a/cedar-policy-core/src/validator/schema/namespace_def.rs +++ b/cedar-policy-core/src/validator/schema/namespace_def.rs @@ -32,7 +32,7 @@ use crate::{ }; use itertools::Itertools; use nonempty::{nonempty, NonEmpty}; -use smol_str::{SmolStr, ToSmolStr}; +use smol_str::SmolStr; use super::{internal_name_to_entity_type, AllDefs, LocatedType, ValidatorApplySpec}; use crate::validator::{ @@ -248,7 +248,7 @@ impl ValidatorNamespaceDef { // The `name` in an entity type declaration cannot be qualified // with a namespace (it always implicitly takes the schema // namespace), so we do this comparison directly. - .any(|(name, _)| name.to_smolstr() == crate::ast::ACTION_ENTITY_TYPE) + .any(|(name, _)| name.0 == crate::ast::ACTION_ENTITY_TYPE) { return Err(ActionEntityTypeDeclaredError {}.into()); } diff --git a/cedar-policy-core/src/validator/typecheck/test/test_utils.rs b/cedar-policy-core/src/validator/typecheck/test/test_utils.rs index 7c2d6c6320..9b439e61e5 100644 --- a/cedar-policy-core/src/validator/typecheck/test/test_utils.rs +++ b/cedar-policy-core/src/validator/typecheck/test/test_utils.rs @@ -99,7 +99,7 @@ impl Typechecker<'_> { principal: &"Principal" .parse() .expect("Placeholder type \"Principal\" failed to parse as valid type name."), - action: &EntityUID::with_eid_and_type(ACTION_ENTITY_TYPE, "action") + action: &EntityUID::with_eid_and_type(ACTION_ENTITY_TYPE.as_ref(), "action") .expect("ACTION_ENTITY_TYPE failed to parse as type name."), resource: &"Resource" .parse()