|
16 | 16 |
|
17 | 17 | use cedar_policy_core::ast::{Id, InternalName, Name}; |
18 | 18 | use cedar_policy_validator::json_schema; |
| 19 | +use cedar_policy_validator::json_schema::EntityTypeKind; |
19 | 20 | use cedar_policy_validator::RawName; |
| 21 | +use cedar_policy_validator::ValidatorEntityTypeKind; |
20 | 22 | use itertools::Itertools; |
21 | 23 | use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; |
22 | 24 |
|
@@ -228,36 +230,67 @@ impl<N: Clone + PartialEq + Debug + Display + TypeName + Ord> Equiv for json_sch |
228 | 230 | fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { |
229 | 231 | Equiv::equiv(&lhs.annotations, &rhs.annotations) |
230 | 232 | .map_err(|e| format!("mismatch in entity annotations: {e}"))?; |
231 | | - Equiv::equiv( |
232 | | - &lhs.member_of_types.iter().collect::<BTreeSet<_>>(), |
233 | | - &rhs.member_of_types.iter().collect::<BTreeSet<_>>(), |
234 | | - ) |
235 | | - .map_err(|e| format!("memberOfTypes are not equal: {e}"))?; |
236 | | - Equiv::equiv(&lhs.shape, &rhs.shape).map_err(|e| format!("mismatched types: {e}"))?; |
237 | | - match (&lhs.tags, &rhs.tags) { |
238 | | - (Some(ts1), Some(ts2)) => { |
239 | | - Equiv::equiv(ts1, ts2).map_err(|msg| format!("mismatched entity tags: {msg}")) |
| 233 | + match (&lhs.kind, &rhs.kind) { |
| 234 | + (EntityTypeKind::Enum { choices: c1 }, EntityTypeKind::Enum { choices: c2 }) => { |
| 235 | + if c1 != c2 { |
| 236 | + Err(format!( |
| 237 | + "enumerated entity types have different eid choices: {c1:?} and {c2:?}" |
| 238 | + )) |
| 239 | + } else { |
| 240 | + Ok(()) |
| 241 | + } |
| 242 | + } |
| 243 | + (EntityTypeKind::Standard(lhs), EntityTypeKind::Standard(rhs)) => { |
| 244 | + Equiv::equiv( |
| 245 | + &lhs.member_of_types.iter().collect::<BTreeSet<_>>(), |
| 246 | + &rhs.member_of_types.iter().collect::<BTreeSet<_>>(), |
| 247 | + ) |
| 248 | + .map_err(|e| format!("memberOfTypes are not equal: {e}"))?; |
| 249 | + Equiv::equiv(&lhs.shape, &rhs.shape) |
| 250 | + .map_err(|e| format!("mismatched types: {e}"))?; |
| 251 | + match (&lhs.tags, &rhs.tags) { |
| 252 | + (Some(ts1), Some(ts2)) => Equiv::equiv(ts1, ts2) |
| 253 | + .map_err(|msg| format!("mismatched entity tags: {msg}")), |
| 254 | + (None, None) => Ok(()), |
| 255 | + (Some(ts), None) | (None, Some(ts)) => { |
| 256 | + Err(format!("only one side has tags: {ts}")) |
| 257 | + } |
| 258 | + } |
240 | 259 | } |
241 | | - (None, None) => Ok(()), |
242 | | - (Some(ts), None) | (None, Some(ts)) => Err(format!("only one side has tags: {ts}")), |
| 260 | + (k1, k2) => Err(format!("different entity type kind: {:?} and {:?}", k1, k2)), |
243 | 261 | } |
244 | 262 | } |
245 | 263 | } |
246 | 264 |
|
247 | 265 | impl Equiv for cedar_policy_validator::ValidatorEntityType { |
248 | 266 | fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { |
249 | | - Equiv::equiv(&lhs.descendants, &rhs.descendants)?; |
250 | | - Equiv::equiv( |
251 | | - &lhs.attributes().collect::<HashMap<_, _>>(), |
252 | | - &rhs.attributes().collect::<HashMap<_, _>>(), |
253 | | - )?; |
254 | | - if lhs.tag_type() != rhs.tag_type() { |
255 | | - return Err(format!( |
256 | | - "encountered different tags types: {:?} and {:?}", |
257 | | - lhs.tag_type(), |
258 | | - rhs.tag_type() |
259 | | - )); |
260 | | - } |
| 267 | + match (&lhs.kind, &rhs.kind) { |
| 268 | + (ValidatorEntityTypeKind::Enum(c1), ValidatorEntityTypeKind::Enum(c2)) => { |
| 269 | + if c1 != c2 { |
| 270 | + return Err(format!( |
| 271 | + "enumerated entity types have different eid choices: {c1:?} and {c2:?}" |
| 272 | + )); |
| 273 | + } |
| 274 | + } |
| 275 | + (ValidatorEntityTypeKind::Standard(_), ValidatorEntityTypeKind::Standard(_)) => { |
| 276 | + Equiv::equiv(&lhs.descendants, &rhs.descendants)?; |
| 277 | + Equiv::equiv( |
| 278 | + &lhs.attributes().iter().collect::<HashMap<_, _>>(), |
| 279 | + &rhs.attributes().iter().collect::<HashMap<_, _>>(), |
| 280 | + )?; |
| 281 | + if lhs.tag_type() != rhs.tag_type() { |
| 282 | + return Err(format!( |
| 283 | + "encountered different tags types: {:?} and {:?}", |
| 284 | + lhs.tag_type(), |
| 285 | + rhs.tag_type() |
| 286 | + )); |
| 287 | + } |
| 288 | + } |
| 289 | + (k1, k2) => { |
| 290 | + return Err(format!("different entity type kind: {:?} and {:?}", k1, k2)); |
| 291 | + } |
| 292 | + }; |
| 293 | + |
261 | 294 | Ok(()) |
262 | 295 | } |
263 | 296 | } |
|
0 commit comments