diff --git a/cedar-language-server/src/schema/fold.rs b/cedar-language-server/src/schema/fold.rs index 29bdde9f59..671ab8fb5d 100644 --- a/cedar-language-server/src/schema/fold.rs +++ b/cedar-language-server/src/schema/fold.rs @@ -80,7 +80,7 @@ pub(crate) fn fold_schema(schema_info: &SchemaInfo) -> Option> .filter_map(|et| et.loc.as_loc_ref()); let action_locs = validator.action_ids().filter_map(|a| a.loc()); let common_types = validator - .common_types() + .common_types_extended() .filter_map(|ct| ct.type_loc.as_loc_ref()); // Combine all locations and create folding ranges diff --git a/cedar-language-server/src/schema/symbols.rs b/cedar-language-server/src/schema/symbols.rs index b0be0e6cc2..602f533746 100644 --- a/cedar-language-server/src/schema/symbols.rs +++ b/cedar-language-server/src/schema/symbols.rs @@ -125,7 +125,7 @@ pub(crate) fn schema_symbols(schema_info: &SchemaInfo) -> Option = validator - .common_types() + .common_types_extended() .filter_map(|ct| { ct.name_loc .as_ref() diff --git a/cedar-policy-core/Cargo.toml b/cedar-policy-core/Cargo.toml index 9f442a6427..b71ff4fe49 100644 --- a/cedar-policy-core/Cargo.toml +++ b/cedar-policy-core/Cargo.toml @@ -62,6 +62,7 @@ wasm = ["serde-wasm-bindgen", "tsify", "wasm-bindgen"] experimental = ["tpe", "tolerant-ast", "extended-schema", "entity-manifest", "partial-validate", "partial-eval"] extended-schema = [] tpe = [] +generalized-templates = [] # Feature for raw parsing raw-parsing = [] diff --git a/cedar-policy-core/src/ast.rs b/cedar-policy-core/src/ast.rs index e8367a647d..00c1f2e9ef 100644 --- a/cedar-policy-core/src/ast.rs +++ b/cedar-policy-core/src/ast.rs @@ -54,5 +54,7 @@ mod expr_iterator; pub use expr_iterator::*; mod annotation; pub use annotation::*; +mod slots_type_declaration; +pub use slots_type_declaration::*; mod expr_visitor; pub use expr_visitor::*; diff --git a/cedar-policy-core/src/ast/expr.rs b/cedar-policy-core/src/ast/expr.rs index ee3c567ed8..66903de6f9 100644 --- a/cedar-policy-core/src/ast/expr.rs +++ b/cedar-policy-core/src/ast/expr.rs @@ -293,13 +293,27 @@ impl Expr { self.subexpressions() .filter_map(|exp| match &exp.expr_kind { ExprKind::Slot(slotid) => Some(Slot { - id: *slotid, + id: slotid.clone(), loc: exp.source_loc().into_maybe_loc(), }), _ => None, }) } + /// Iterate over all of the principal and resource slots in this policy AST + pub fn principal_and_resource_slots(&self) -> impl Iterator + '_ { + self.subexpressions() + .filter_map(|exp| match &exp.expr_kind { + ExprKind::Slot(slotid) if slotid.is_principal() || slotid.is_resource() => { + Some(Slot { + id: slotid.clone(), + loc: exp.source_loc().into_maybe_loc(), + }) + } + _ => None, + }) + } + /// Determine if the expression is projectable under partial evaluation /// An expression is projectable if it's guaranteed to never error on evaluation /// This is true if the expression is entirely composed of values or unknowns @@ -1842,7 +1856,7 @@ mod test { let e = Expr::slot(SlotId::principal()); let p = SlotId::principal(); let r = SlotId::resource(); - let set: HashSet = HashSet::from_iter([p]); + let set: HashSet = HashSet::from_iter([p.clone()]); assert_eq!(set, e.slots().map(|slot| slot.id).collect::>()); let e = Expr::or( Expr::slot(SlotId::principal()), diff --git a/cedar-policy-core/src/ast/expr_visitor.rs b/cedar-policy-core/src/ast/expr_visitor.rs index ee6acd3f26..c5adb9b353 100644 --- a/cedar-policy-core/src/ast/expr_visitor.rs +++ b/cedar-policy-core/src/ast/expr_visitor.rs @@ -55,7 +55,7 @@ pub trait ExprVisitor { match expr.expr_kind() { ExprKind::Lit(lit) => self.visit_literal(lit, loc), ExprKind::Var(var) => self.visit_var(*var, loc), - ExprKind::Slot(slot) => self.visit_slot(*slot, loc), + ExprKind::Slot(slot) => self.visit_slot(slot.clone(), loc), ExprKind::Unknown(unknown) => self.visit_unknown(unknown, loc), ExprKind::If { test_expr, diff --git a/cedar-policy-core/src/ast/name.rs b/cedar-policy-core/src/ast/name.rs index 8482cf6e39..72430ae8d2 100644 --- a/cedar-policy-core/src/ast/name.rs +++ b/cedar-policy-core/src/ast/name.rs @@ -21,6 +21,7 @@ use miette::Diagnostic; use ref_cast::RefCast; use regex::Regex; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_with::{serde_as, DisplayFromStr}; use smol_str::ToSmolStr; use std::collections::HashSet; use std::fmt::Display; @@ -283,9 +284,10 @@ impl<'de> Deserialize<'de> for InternalName { /// Clone is O(1). // This simply wraps a separate enum -- currently [`ValidSlotId`] -- in case we // want to generalize later -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[serde_as] +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] #[serde(transparent)] -pub struct SlotId(pub(crate) ValidSlotId); +pub struct SlotId(#[serde_as(as = "DisplayFromStr")] pub(crate) ValidSlotId); impl SlotId { /// Get the slot for `principal` @@ -298,6 +300,11 @@ impl SlotId { Self(ValidSlotId::Resource) } + /// Create a `generalized slot` + pub fn generalized_slot(id: Id) -> Self { + Self(ValidSlotId::GeneralizedSlot(id)) + } + /// Check if a slot represents a principal pub fn is_principal(&self) -> bool { matches!(self, Self(ValidSlotId::Principal)) @@ -307,6 +314,11 @@ impl SlotId { pub fn is_resource(&self) -> bool { matches!(self, Self(ValidSlotId::Resource)) } + + /// Check if a slot represents a generalized slot + pub fn is_generalized_slot(&self) -> bool { + matches!(self, Self(ValidSlotId::GeneralizedSlot(_))) + } } impl From for SlotId { @@ -318,28 +330,44 @@ impl From for SlotId { } } +impl std::fmt::Display for ValidSlotId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + ValidSlotId::Principal => "principal", + ValidSlotId::Resource => "resource", + ValidSlotId::GeneralizedSlot(id) => id.as_ref(), + }; + write!(f, "?{s}") + } +} + +impl FromStr for SlotId { + type Err = ParseErrors; + + fn from_str(s: &str) -> Result { + s.parse().map(SlotId) + } +} + impl std::fmt::Display for SlotId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } -/// Two possible variants for Slots -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +/// Three possible variants for Slots +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] pub(crate) enum ValidSlotId { - #[serde(rename = "?principal")] Principal, - #[serde(rename = "?resource")] Resource, + GeneralizedSlot(Id), // Slots for generalized templates, for more info see [RFC 98](https://github.com/cedar-policy/rfcs/pull/98). } -impl std::fmt::Display for ValidSlotId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = match self { - ValidSlotId::Principal => "principal", - ValidSlotId::Resource => "resource", - }; - write!(f, "?{s}") +impl FromStr for ValidSlotId { + type Err = ParseErrors; + + fn from_str(s: &str) -> Result { + crate::parser::parse_slot(s) } } diff --git a/cedar-policy-core/src/ast/policy.rs b/cedar-policy-core/src/ast/policy.rs index 191d66079e..3c041e8009 100644 --- a/cedar-policy-core/src/ast/policy.rs +++ b/cedar-policy-core/src/ast/policy.rs @@ -15,13 +15,20 @@ */ use crate::ast::*; +use crate::entities::{conformance::typecheck_restricted_expr_against_schematype, SchemaType}; +use crate::extensions::Extensions; use crate::parser::{AsLocRef, IntoMaybeLoc, Loc, MaybeLoc}; +use crate::validator::{ + err::SchemaError, json_schema::Type as JSONSchemaType, types::Type as ValidatorType, RawName, + ValidatorSchema, +}; use annotation::{Annotation, Annotations}; use educe::Educe; use itertools::Itertools; use miette::Diagnostic; use nonempty::{nonempty, NonEmpty}; use serde::{Deserialize, Serialize}; +use slots_type_declaration::SlotsTypeDeclaration; use smol_str::SmolStr; use std::{ collections::{HashMap, HashSet}, @@ -53,6 +60,9 @@ cfg_tolerant_ast! { static DEFAULT_ANNOTATIONS: std::sync::LazyLock> = std::sync::LazyLock::new(|| Arc::new(Annotations::default())); + static DEFAULT_SLOTS_TYPE_DECLARATION: std::sync::LazyLock> = + std::sync::LazyLock::new(|| Arc::new(SlotsTypeDeclaration::default())); + static DEFAULT_PRINCIPAL_CONSTRAINT: std::sync::LazyLock = std::sync::LazyLock::new(PrincipalConstraint::any); @@ -120,6 +130,7 @@ impl Template { id: PolicyID, loc: MaybeLoc, annotations: Annotations, + slots_type_declaration: SlotsTypeDeclaration, effect: Effect, principal_constraint: PrincipalConstraint, action_constraint: ActionConstraint, @@ -130,6 +141,7 @@ impl Template { id, loc, annotations, + slots_type_declaration, effect, principal_constraint, action_constraint, @@ -154,6 +166,7 @@ impl Template { id: PolicyID, loc: MaybeLoc, annotations: Arc, + slots_type_declaration: Arc, effect: Effect, principal_constraint: PrincipalConstraint, action_constraint: ActionConstraint, @@ -164,6 +177,7 @@ impl Template { id, loc, annotations, + slots_type_declaration, effect, principal_constraint, action_constraint, @@ -238,6 +252,18 @@ impl Template { self.body.annotations_arc() } + /// Get all slots_type_declaration data. + pub fn slots_type_declaration( + &self, + ) -> impl Iterator)> { + self.body.slots_type_declaration() + } + + /// Get [`Arc`] owning the slots_type_declaration data. + pub fn slots_type_declaration_arc(&self) -> &Arc { + self.body.slots_type_declaration_arc() + } + /// Get the condition expression of this template. /// /// This will be a conjunction of the template's scope constraints (on @@ -247,11 +273,25 @@ impl Template { self.body.condition() } - /// List of open slots in this template + /// List of open slots in this template including principal, resource, and generalized slots pub fn slots(&self) -> impl Iterator { self.slots.iter() } + /// List of principal and resource slots in this template + pub fn principal_resource_slots(&self) -> impl Iterator { + self.slots + .iter() + .filter(|slot| slot.id.is_principal() || slot.id.is_resource()) + } + + /// List of generalized slots in this template + pub fn generalized_slots(&self) -> impl Iterator { + self.slots + .iter() + .filter(|slot| slot.id.is_generalized_slot()) + } + /// Check if this template is a static policy /// /// Static policies can be linked without any slots, @@ -263,18 +303,40 @@ impl Template { /// Ensure that every slot in the template is bound by values, /// and that no extra values are bound in values /// This upholds invariant (values total map) + /// + /// All callers of this function + /// must enforce INVARIANT that `?principal` and `?resource` slots + /// are in values and generalized slots are in generalized_values pub fn check_binding( template: &Template, values: &HashMap, + generalized_values: &HashMap, ) -> Result<(), LinkingError> { // Verify all slots bound - let unbound = template + let unbound_values_and_generalized_values = template .slots .iter() - .filter(|slot| !values.contains_key(&slot.id)) + .filter(|slot| { + !values.contains_key(&slot.id) && !generalized_values.contains_key(&slot.id) + }) + .collect::>(); + + let extra_values = values + .iter() + .filter_map(|(slot, _)| { + if !template + .slots + .iter() + .any(|template_slot| template_slot.id == *slot) + { + Some(slot) + } else { + None + } + }) .collect::>(); - let extra = values + let extra_generalized_values = generalized_values .iter() .filter_map(|(slot, _)| { if !template @@ -289,16 +351,75 @@ impl Template { }) .collect::>(); - if unbound.is_empty() && extra.is_empty() { + if unbound_values_and_generalized_values.is_empty() + && extra_values.is_empty() + && extra_generalized_values.is_empty() + { Ok(()) } else { Err(LinkingError::from_unbound_and_extras( - unbound.into_iter().map(|slot| slot.id), - extra.into_iter().copied(), + unbound_values_and_generalized_values + .into_iter() + .map(|slot| slot.id.clone()), + extra_values + .into_iter() + .cloned() + .chain(extra_generalized_values.into_iter().cloned()), )) } } + /// Validates that the values provided for the generalized slots are of the types declared + pub fn link_time_type_checking( + template: &Template, + schema: Option<&ValidatorSchema>, + values: &HashMap, + generalized_values: &HashMap, + ) -> Result<(), LinkingError> { + let slots_type_declaration = SlotsTypeDeclaration::from_iter( + template + .slots_type_declaration() + .map(|(k, v)| (k.clone(), v.clone())), + ); + let validator_slots_type_declaration = match schema { + Some(schema) => slots_type_declaration.into_validator_slots_type_declaration(schema)?, + None => { + slots_type_declaration.into_validator_slots_type_declaration_without_schema()? + } + }; + + // Loop through all slots that have a type declaration and check that + // the values provided are of that type + for (slot, validator_type) in validator_slots_type_declaration.0 { + // PANIC SAFETY + // all slot values should binded + let restricted_expr = if slot.is_principal() || slot.is_resource() { + #[allow(clippy::unwrap_used)] + RestrictedExpr::val(values.get(&slot).unwrap().clone()) + } else { + #[allow(clippy::unwrap_used)] + generalized_values.get(&slot).unwrap().clone() + }; + let borrowed_restricted_expr = restricted_expr.as_borrowed(); + #[allow(clippy::expect_used)] + let schema_ty = &SchemaType::try_from(validator_type.clone()) + .expect("This should never happen as expected_ty is a statically annotated type"); + let extensions = Extensions::all_available(); + typecheck_restricted_expr_against_schematype( + borrowed_restricted_expr, + schema_ty, + extensions, + ) + .map_err(|_| LinkingError::ValueProvidedForSlotIsNotOfTypeSpecified { + slot: slot.clone(), + value: restricted_expr.clone(), + ty: validator_type.clone(), + })? + } + + Ok(()) + } + /// Attempt to create a template-linked policy from this template. /// This will fail if values for all open slots are not given. /// `new_instance_id` is the `PolicyId` for the created template-linked policy. @@ -306,10 +427,15 @@ impl Template { template: Arc