Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 60 additions & 21 deletions engine/src/execution_context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
scheme::{Field, List, Scheme, SchemeMismatchError},
types::{GetType, LhsValue, LhsValueSeed, Type, TypeMismatchError},
ListMatcher,
ListMatcher, UnknownFieldError,
};
use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, Visitor};
use serde::ser::{SerializeMap, SerializeSeq, Serializer};
Expand All @@ -16,11 +16,15 @@ use thiserror::Error;
pub enum SetFieldValueError {
/// An error that occurs when trying to assign a value of the wrong type to a field.
#[error("{0}")]
TypeMismatchError(#[source] TypeMismatchError),
TypeMismatch(#[source] TypeMismatchError),

/// An error that occurs when trying to set the value of a field from a different scheme.
#[error("{0}")]
SchemeMismatchError(#[source] SchemeMismatchError),
SchemeMismatch(#[source] SchemeMismatchError),

/// An error that occurs when specifying an unknown field name.
#[error("{0}")]
UnknownField(#[source] UnknownFieldError),
}

/// An error that occurs when previously defined list gets redefined.
Expand Down Expand Up @@ -82,7 +86,7 @@ impl<'e, U> ExecutionContext<'e, U> {
value: V,
) -> Result<Option<LhsValue<'e>>, SetFieldValueError> {
if !std::ptr::eq(self.scheme, field.scheme()) {
return Err(SetFieldValueError::SchemeMismatchError(SchemeMismatchError));
return Err(SetFieldValueError::SchemeMismatch(SchemeMismatchError));
}
let value = value.into();

Expand All @@ -92,7 +96,32 @@ impl<'e, U> ExecutionContext<'e, U> {
if field_type == value_type {
Ok(self.values[field.index()].replace(value))
} else {
Err(SetFieldValueError::TypeMismatchError(TypeMismatchError {
Err(SetFieldValueError::TypeMismatch(TypeMismatchError {
expected: field_type.into(),
actual: value_type,
}))
}
}

/// Sets a runtime value for a given field name.
pub fn set_field_value_from_name<'v: 'e, V: Into<LhsValue<'v>>>(
&mut self,
name: &str,
value: V,
) -> Result<Option<LhsValue<'e>>, SetFieldValueError> {
let field = self
.scheme
.get_field(name)
.map_err(SetFieldValueError::UnknownField)?;
let value = value.into();

let field_type = field.get_type();
let value_type = value.get_type();

if field_type == value_type {
Ok(self.values[field.index()].replace(value))
} else {
Err(SetFieldValueError::TypeMismatch(TypeMismatchError {
expected: field_type.into(),
actual: value_type,
}))
Expand Down Expand Up @@ -138,13 +167,25 @@ impl<'e, U> ExecutionContext<'e, U> {
&*self.list_matchers[list.index()]
}

/// Get the list matcher object for the specified type.
pub fn get_list_matcher_from_type(&self, ty: &Type) -> Option<&dyn ListMatcher> {
let list = self.scheme.get_list(ty)?;
Some(&*self.list_matchers[list.index()])
}

/// Get the list matcher object for the specified list type.
pub fn get_list_matcher_mut(&mut self, list: List<'_>) -> &mut dyn ListMatcher {
assert!(self.scheme() == list.scheme());

&mut *self.list_matchers[list.index()]
}

/// Get the list matcher object for the specified type.
pub fn get_list_matcher_mut_from_type(&mut self, ty: &Type) -> Option<&mut dyn ListMatcher> {
let list = self.scheme.get_list(ty)?;
Some(&mut *self.list_matchers[list.index()])
}

/// Get immutable reference to user data stored in
/// this execution context with [`ExecutionContext::new_with`].
#[inline]
Expand Down Expand Up @@ -285,18 +326,18 @@ impl<'de, U> DeserializeSeed<'de> for &mut ExecutionContext<'de, U> {
.map_err(|_| de::Error::custom(format!("unknown field: {key}")))?;
let value = access
.next_value_seed::<LhsValueSeed<'_>>(LhsValueSeed(&field.get_type()))?;
let field = self
.0
.scheme()
.get_field(&key)
.map_err(|_| de::Error::custom(format!("unknown field: {key}")))?;
self.0.set_field_value(field, value).map_err(|e| match e {
SetFieldValueError::TypeMismatchError(e) => de::Error::custom(format!(
"invalid type: {:?}, expected {:?}",
e.actual, e.expected
)),
SetFieldValueError::SchemeMismatchError(_) => unreachable!(),
})?;
self.0
.set_field_value_from_name(&key, value)
.map_err(|e| match e {
SetFieldValueError::UnknownField(UnknownFieldError) => {
de::Error::custom(format!("unknown field name `{key}`",))
}
SetFieldValueError::TypeMismatch(e) => de::Error::custom(format!(
"invalid type: {:?}, expected {:?}",
e.actual, e.expected
)),
SetFieldValueError::SchemeMismatch(_) => unreachable!(),
})?;
}
}

Expand Down Expand Up @@ -361,7 +402,7 @@ fn test_field_value_type_mismatch() {

assert_eq!(
ctx.set_field_value(scheme.get_field("foo").unwrap(), LhsValue::Bool(false)),
Err(SetFieldValueError::TypeMismatchError(TypeMismatchError {
Err(SetFieldValueError::TypeMismatch(TypeMismatchError {
expected: Type::Int.into(),
actual: Type::Bool,
}))
Expand All @@ -378,9 +419,7 @@ fn test_scheme_mismatch() {

assert_eq!(
ctx.set_field_value(scheme2.get_field("foo").unwrap(), LhsValue::Bool(false)),
Err(SetFieldValueError::SchemeMismatchError(
SchemeMismatchError {}
))
Err(SetFieldValueError::SchemeMismatch(SchemeMismatchError {}))
);
}

Expand Down
Loading