diff --git a/engine/src/execution_context.rs b/engine/src/execution_context.rs index a47ab632..62f32cd8 100644 --- a/engine/src/execution_context.rs +++ b/engine/src/execution_context.rs @@ -1,7 +1,7 @@ use crate::{ scheme::{Field, List, Scheme, SchemeMismatchError}, types::{GetType, LhsValue, LhsValueSeed, Type, TypeMismatchError}, - ListMatcher, + ListMatcher, UnknownFieldError, }; use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, Visitor}; use serde::ser::{SerializeMap, SerializeSeq, Serializer}; @@ -16,11 +16,15 @@ use thiserror::Error; pub enum SetFieldValueError { /// An error that occurs when trying to assign a value of the wrong type to a field. #[error("{0}")] - TypeMismatchError(#[source] TypeMismatchError), + TypeMismatch(#[source] TypeMismatchError), /// An error that occurs when trying to set the value of a field from a different scheme. #[error("{0}")] - SchemeMismatchError(#[source] SchemeMismatchError), + SchemeMismatch(#[source] SchemeMismatchError), + + /// An error that occurs when specifying an unknown field name. + #[error("{0}")] + UnknownField(#[source] UnknownFieldError), } /// An error that occurs when previously defined list gets redefined. @@ -82,7 +86,7 @@ impl<'e, U> ExecutionContext<'e, U> { value: V, ) -> Result>, SetFieldValueError> { if !std::ptr::eq(self.scheme, field.scheme()) { - return Err(SetFieldValueError::SchemeMismatchError(SchemeMismatchError)); + return Err(SetFieldValueError::SchemeMismatch(SchemeMismatchError)); } let value = value.into(); @@ -92,7 +96,32 @@ impl<'e, U> ExecutionContext<'e, U> { if field_type == value_type { Ok(self.values[field.index()].replace(value)) } else { - Err(SetFieldValueError::TypeMismatchError(TypeMismatchError { + Err(SetFieldValueError::TypeMismatch(TypeMismatchError { + expected: field_type.into(), + actual: value_type, + })) + } + } + + /// Sets a runtime value for a given field name. + pub fn set_field_value_from_name<'v: 'e, V: Into>>( + &mut self, + name: &str, + value: V, + ) -> Result>, SetFieldValueError> { + let field = self + .scheme + .get_field(name) + .map_err(SetFieldValueError::UnknownField)?; + let value = value.into(); + + let field_type = field.get_type(); + let value_type = value.get_type(); + + if field_type == value_type { + Ok(self.values[field.index()].replace(value)) + } else { + Err(SetFieldValueError::TypeMismatch(TypeMismatchError { expected: field_type.into(), actual: value_type, })) @@ -138,6 +167,12 @@ impl<'e, U> ExecutionContext<'e, U> { &*self.list_matchers[list.index()] } + /// Get the list matcher object for the specified type. + pub fn get_list_matcher_from_type(&self, ty: &Type) -> Option<&dyn ListMatcher> { + let list = self.scheme.get_list(ty)?; + Some(&*self.list_matchers[list.index()]) + } + /// Get the list matcher object for the specified list type. pub fn get_list_matcher_mut(&mut self, list: List<'_>) -> &mut dyn ListMatcher { assert!(self.scheme() == list.scheme()); @@ -145,6 +180,12 @@ impl<'e, U> ExecutionContext<'e, U> { &mut *self.list_matchers[list.index()] } + /// Get the list matcher object for the specified type. + pub fn get_list_matcher_mut_from_type(&mut self, ty: &Type) -> Option<&mut dyn ListMatcher> { + let list = self.scheme.get_list(ty)?; + Some(&mut *self.list_matchers[list.index()]) + } + /// Get immutable reference to user data stored in /// this execution context with [`ExecutionContext::new_with`]. #[inline] @@ -285,18 +326,18 @@ impl<'de, U> DeserializeSeed<'de> for &mut ExecutionContext<'de, U> { .map_err(|_| de::Error::custom(format!("unknown field: {key}")))?; let value = access .next_value_seed::>(LhsValueSeed(&field.get_type()))?; - let field = self - .0 - .scheme() - .get_field(&key) - .map_err(|_| de::Error::custom(format!("unknown field: {key}")))?; - self.0.set_field_value(field, value).map_err(|e| match e { - SetFieldValueError::TypeMismatchError(e) => de::Error::custom(format!( - "invalid type: {:?}, expected {:?}", - e.actual, e.expected - )), - SetFieldValueError::SchemeMismatchError(_) => unreachable!(), - })?; + self.0 + .set_field_value_from_name(&key, value) + .map_err(|e| match e { + SetFieldValueError::UnknownField(UnknownFieldError) => { + de::Error::custom(format!("unknown field name `{key}`",)) + } + SetFieldValueError::TypeMismatch(e) => de::Error::custom(format!( + "invalid type: {:?}, expected {:?}", + e.actual, e.expected + )), + SetFieldValueError::SchemeMismatch(_) => unreachable!(), + })?; } } @@ -361,7 +402,7 @@ fn test_field_value_type_mismatch() { assert_eq!( ctx.set_field_value(scheme.get_field("foo").unwrap(), LhsValue::Bool(false)), - Err(SetFieldValueError::TypeMismatchError(TypeMismatchError { + Err(SetFieldValueError::TypeMismatch(TypeMismatchError { expected: Type::Int.into(), actual: Type::Bool, })) @@ -378,9 +419,7 @@ fn test_scheme_mismatch() { assert_eq!( ctx.set_field_value(scheme2.get_field("foo").unwrap(), LhsValue::Bool(false)), - Err(SetFieldValueError::SchemeMismatchError( - SchemeMismatchError {} - )) + Err(SetFieldValueError::SchemeMismatch(SchemeMismatchError {})) ); }