Skip to content

Commit 508a541

Browse files
committed
Handle OOM in WasmFuncType's serde support
1 parent 25e3bd1 commit 508a541

File tree

1 file changed

+171
-1
lines changed

1 file changed

+171
-1
lines changed

crates/environ/src/types.rs

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,14 +687,184 @@ pub enum WasmHeapBottomType {
687687
}
688688

689689
/// WebAssembly function type -- equivalent of `wasmparser`'s FuncType.
690-
#[derive(Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
690+
#[derive(Debug, Eq, PartialEq, Hash)]
691691
pub struct WasmFuncType {
692692
params_results: Box<[WasmValType]>,
693693
params_len: u32,
694694
non_i31_gc_ref_params_count: u32,
695695
non_i31_gc_ref_results_count: u32,
696696
}
697697

698+
impl serde::Serialize for WasmFuncType {
699+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
700+
where
701+
S: serde::Serializer,
702+
{
703+
use serde::ser::SerializeStruct as _;
704+
let mut s = serializer.serialize_struct("WasmFuncType", 2)?;
705+
s.serialize_field("params_results", &self.params_results)?;
706+
s.serialize_field("params_len", &self.params_len)?;
707+
s.end()
708+
}
709+
}
710+
711+
impl<'de> serde::Deserialize<'de> for WasmFuncType {
712+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
713+
where
714+
D: serde::Deserializer<'de>,
715+
{
716+
enum Field {
717+
ParamsResults,
718+
ParamsLen,
719+
}
720+
721+
const FIELDS: &[&str] = &["params_results", "params_len"];
722+
723+
impl<'de> serde::Deserialize<'de> for Field {
724+
fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
725+
where
726+
D: serde::Deserializer<'de>,
727+
{
728+
struct FieldVisitor;
729+
730+
impl<'de> serde::de::Visitor<'de> for FieldVisitor {
731+
type Value = Field;
732+
733+
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
734+
f.write_str("`params_results` or `params_len`")
735+
}
736+
737+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
738+
where
739+
E: serde::de::Error,
740+
{
741+
match value {
742+
"params_results" => Ok(Field::ParamsResults),
743+
"params_len" => Ok(Field::ParamsLen),
744+
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
745+
}
746+
}
747+
}
748+
749+
deserializer.deserialize_identifier(FieldVisitor)
750+
}
751+
}
752+
753+
struct Visitor;
754+
755+
fn from_params_results_and_params_len<E>(
756+
params_results: crate::collections::Vec<WasmValType>,
757+
params_len: u64,
758+
) -> Result<WasmFuncType, E>
759+
where
760+
E: serde::de::Error,
761+
{
762+
let params_results = params_results
763+
.into_boxed_slice()
764+
.map_err(|oom| serde::de::Error::custom(oom))?;
765+
766+
struct ExpectedLen(usize);
767+
impl serde::de::Expected for ExpectedLen {
768+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
769+
write!(f, "<= {}", self.0)
770+
}
771+
}
772+
773+
let params_len = u32::try_from(params_len).map_err(|_| {
774+
serde::de::Error::invalid_value(
775+
serde::de::Unexpected::Unsigned(params_len),
776+
&ExpectedLen(params_results.len()),
777+
)
778+
})?;
779+
780+
let (non_i31_gc_ref_params_count, non_i31_gc_ref_results_count) = {
781+
let params_len = usize::try_from(params_len).unwrap();
782+
if params_len > params_results.len() {
783+
return Err(serde::de::Error::invalid_length(
784+
params_len,
785+
&ExpectedLen(params_results.len()),
786+
));
787+
}
788+
(
789+
u32::try_from(
790+
params_results[..params_len]
791+
.iter()
792+
.filter(|p| p.is_vmgcref_type_and_not_i31())
793+
.count(),
794+
)
795+
.unwrap(),
796+
u32::try_from(
797+
params_results[params_len..]
798+
.iter()
799+
.filter(|p| p.is_vmgcref_type_and_not_i31())
800+
.count(),
801+
)
802+
.unwrap(),
803+
)
804+
};
805+
806+
Ok(WasmFuncType {
807+
params_results,
808+
params_len,
809+
non_i31_gc_ref_params_count,
810+
non_i31_gc_ref_results_count,
811+
})
812+
}
813+
814+
impl<'de> serde::de::Visitor<'de> for Visitor {
815+
type Value = WasmFuncType;
816+
817+
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
818+
f.write_str("struct WasmFuncType")
819+
}
820+
821+
fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
822+
where
823+
V: serde::de::SeqAccess<'de>,
824+
{
825+
let params_results: crate::collections::Vec<WasmValType> = seq
826+
.next_element()?
827+
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
828+
let params_len: u64 = seq
829+
.next_element()?
830+
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
831+
from_params_results_and_params_len(params_results, params_len)
832+
}
833+
834+
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
835+
where
836+
V: serde::de::MapAccess<'de>,
837+
{
838+
let mut params_results: Option<crate::collections::Vec<WasmValType>> = None;
839+
let mut params_len: Option<u64> = None;
840+
while let Some(key) = map.next_key()? {
841+
match key {
842+
Field::ParamsResults => {
843+
if params_results.is_some() {
844+
return Err(serde::de::Error::duplicate_field("params_results"));
845+
}
846+
params_results = Some(map.next_value()?);
847+
}
848+
Field::ParamsLen => {
849+
if params_len.is_some() {
850+
return Err(serde::de::Error::duplicate_field("params_len"));
851+
}
852+
params_len = Some(map.next_value()?);
853+
}
854+
}
855+
}
856+
let params_results = params_results
857+
.ok_or_else(|| serde::de::Error::missing_field("params_results"))?;
858+
let params_len =
859+
params_len.ok_or_else(|| serde::de::Error::missing_field("params_len"))?;
860+
from_params_results_and_params_len(params_results, params_len)
861+
}
862+
}
863+
864+
deserializer.deserialize_struct("WasmFuncType", FIELDS, Visitor)
865+
}
866+
}
867+
698868
impl TryClone for WasmFuncType {
699869
fn try_clone(&self) -> Result<Self, OutOfMemory> {
700870
Ok(Self {

0 commit comments

Comments
 (0)