Skip to content

Commit 7c037ed

Browse files
committed
fix: avoid EnumAccess deserialization paths
1 parent ee6a744 commit 7c037ed

File tree

2 files changed

+167
-15
lines changed

2 files changed

+167
-15
lines changed

pythnet/pythnet_sdk/src/wire.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,126 @@ mod tests {
368368
);
369369
}
370370

371+
#[test]
372+
#[rustfmt::skip]
373+
/// This method tests that our EnumAccess workaround does not violate any memory safety rules.
374+
/// In particular we want to make sure we avoid transmuting to any type that is not a u8 (we do
375+
/// not support > 255 variants anyway).
376+
fn test_serde_enum_access_behaviour() {
377+
use serde::Deserialize;
378+
use serde::Serialize;
379+
380+
// Small-sized enums should all deserialize safely as single u8.
381+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
382+
enum Singleton { A }
383+
384+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
385+
enum Pair { A, B }
386+
387+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
388+
enum Triple { A, B, C }
389+
390+
// Intentionally numbered enums with primitive representation (as long as u8) are safe.
391+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
392+
enum CustomIndices {
393+
A = 33,
394+
B = 55,
395+
C = 255,
396+
}
397+
398+
// Complex enum's should still serialize as u8, and we expect the serde EnumAccess to work
399+
// the same.
400+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
401+
enum Complex {
402+
A,
403+
B(u8, u8),
404+
C { a: u8, b: u8 },
405+
}
406+
407+
// Forces the compiler to use a 16-bit discriminant. This must force the serde EnumAccess
408+
// implementation to return an error. Otherwise we run the risk of the __Field enum in our
409+
// transmute workaround becoming trash memory leading to UB.
410+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
411+
enum ManyVariants {
412+
_000, _001, _002, _003, _004, _005, _006, _007, _008, _009, _00A, _00B, _00C, _00D,
413+
_00E, _00F, _010, _011, _012, _013, _014, _015, _016, _017, _018, _019, _01A, _01B,
414+
_01C, _01D, _01E, _01F, _020, _021, _022, _023, _024, _025, _026, _027, _028, _029,
415+
_02A, _02B, _02C, _02D, _02E, _02F, _030, _031, _032, _033, _034, _035, _036, _037,
416+
_038, _039, _03A, _03B, _03C, _03D, _03E, _03F, _040, _041, _042, _043, _044, _045,
417+
_046, _047, _048, _049, _04A, _04B, _04C, _04D, _04E, _04F, _050, _051, _052, _053,
418+
_054, _055, _056, _057, _058, _059, _05A, _05B, _05C, _05D, _05E, _05F, _060, _061,
419+
_062, _063, _064, _065, _066, _067, _068, _069, _06A, _06B, _06C, _06D, _06E, _06F,
420+
_070, _071, _072, _073, _074, _075, _076, _077, _078, _079, _07A, _07B, _07C, _07D,
421+
_07E, _07F, _080, _081, _082, _083, _084, _085, _086, _087, _088, _089, _08A, _08B,
422+
_08C, _08D, _08E, _08F, _090, _091, _092, _093, _094, _095, _096, _097, _098, _099,
423+
_09A, _09B, _09C, _09D, _09E, _09F, _0A0, _0A1, _0A2, _0A3, _0A4, _0A5, _0A6, _0A7,
424+
_0A8, _0A9, _0AA, _0AB, _0AC, _0AD, _0AE, _0AF, _0B0, _0B1, _0B2, _0B3, _0B4, _0B5,
425+
_0B6, _0B7, _0B8, _0B9, _0BA, _0BB, _0BC, _0BD, _0BE, _0BF, _0C0, _0C1, _0C2, _0C3,
426+
_0C4, _0C5, _0C6, _0C7, _0C8, _0C9, _0CA, _0CB, _0CC, _0CD, _0CE, _0CF, _0D0, _0D1,
427+
_0D2, _0D3, _0D4, _0D5, _0D6, _0D7, _0D8, _0D9, _0DA, _0DB, _0DC, _0DD, _0DE, _0DF,
428+
_0E0, _0E1, _0E2, _0E3, _0E4, _0E5, _0E6, _0E7, _0E8, _0E9, _0EA, _0EB, _0EC, _0ED,
429+
_0EE, _0EF, _0F0, _0F1, _0F2, _0F3, _0F4, _0F5, _0F6, _0F7, _0F8, _0F9, _0FA, _0FB,
430+
_0FC, _0FD, _0FE, _0FF,
431+
432+
// > 255
433+
_100
434+
}
435+
436+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
437+
struct AllValid {
438+
singleton: Singleton,
439+
pair: Pair,
440+
triple: Triple,
441+
complex: Complex,
442+
custom: CustomIndices,
443+
}
444+
445+
#[derive(PartialEq, Serialize, Deserialize, Debug)]
446+
struct Invalid {
447+
many_variants: ManyVariants,
448+
}
449+
450+
let valid_buffer = [
451+
// Singleton (A)
452+
0,
453+
// Pair (B)
454+
1,
455+
// Triple (C)
456+
2,
457+
// Complex
458+
1, 0, 0,
459+
// Custom
460+
2,
461+
];
462+
463+
let valid_struct = AllValid {
464+
singleton: Singleton::A,
465+
pair: Pair::B,
466+
triple: Triple::C,
467+
complex: Complex::B(0, 0),
468+
custom: CustomIndices::C,
469+
};
470+
471+
let valid_serialized = crate::wire::ser::to_vec::<_, byteorder::BE>(&valid_struct).unwrap();
472+
473+
// Confirm that the valid buffer can be deserialized.
474+
let valid = crate::wire::from_slice::<byteorder::BE, AllValid>(&valid_buffer).unwrap();
475+
let valid_deserialized = crate::wire::from_slice::<byteorder::BE, AllValid>(&valid_serialized).unwrap();
476+
assert_eq!(valid, valid_struct);
477+
assert_eq!(valid_deserialized, valid_struct);
478+
479+
// Invalid buffer tests that types > u8 fail to deserialize, it's important to note that
480+
// there is nothing stopping someone compiling a program with an invalid enum deserialize
481+
// but we can at least ensure an error in deserialization occurs.
482+
let invalid_buffer = [
483+
// ManyVariants (256)
484+
1, 0
485+
];
486+
487+
let result = crate::wire::from_slice::<byteorder::BE, Invalid>(&invalid_buffer);
488+
assert!(result.is_err());
489+
}
490+
371491
// Test if the AccumulatorUpdateData type can be serialized and deserialized
372492
// and still be the same as the original.
373493
#[test]

pythnet/pythnet_sdk/src/wire/de.rs

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,27 @@
2323
//! https://serde.rs/impl-deserializer.html
2424
2525
use {
26+
crate::require,
2627
byteorder::{
2728
ByteOrder,
2829
ReadBytesExt,
2930
},
3031
serde::{
3132
de::{
3233
EnumAccess,
33-
IntoDeserializer,
3434
MapAccess,
3535
SeqAccess,
3636
VariantAccess,
3737
},
3838
Deserialize,
3939
},
40-
std::io::{
41-
Cursor,
42-
Seek,
43-
SeekFrom,
40+
std::{
41+
io::{
42+
Cursor,
43+
Seek,
44+
SeekFrom,
45+
},
46+
mem::size_of,
4447
},
4548
thiserror::Error,
4649
};
@@ -78,6 +81,9 @@ pub enum DeserializerError {
7881
#[error("message: {0}")]
7982
Message(Box<str>),
8083

84+
#[error("invalid enum variant, higher than expected variant range")]
85+
InvalidEnumVariant,
86+
8187
#[error("eof")]
8288
Eof,
8389
}
@@ -417,7 +423,7 @@ where
417423
fn deserialize_enum<V>(
418424
self,
419425
_name: &'static str,
420-
_variants: &'static [&'static str],
426+
variants: &'static [&'static str],
421427
visitor: V,
422428
) -> Result<V::Value, Self::Error>
423429
where
@@ -426,6 +432,10 @@ where
426432
// We read the discriminator here so that we can make the expected enum variant available
427433
// to the `visit_enum` call.
428434
let variant = self.cursor.read_u8().map_err(DeserializerError::from)?;
435+
if variant >= variants.len() as u8 {
436+
return Err(DeserializerError::InvalidEnumVariant);
437+
}
438+
429439
visitor.visit_enum(Enum { de: self, variant })
430440
}
431441

@@ -537,17 +547,39 @@ impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> {
537547
type Error = DeserializerError;
538548
type Variant = &'a mut Deserializer<'de, B>;
539549

540-
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
550+
fn variant_seed<V>(self, _: V) -> Result<(V::Value, Self::Variant), Self::Error>
541551
where
542552
V: serde::de::DeserializeSeed<'de>,
543553
{
544-
// This is a trick to get around Serde's expectation of a u32 for discriminants. The `seed`
545-
// here is the generated `Field` type from the `#[derive(Deserialize)]` macro. If we
546-
// attempt to deserialize this normally it will attempt to deserialize a u32 and fail.
547-
// Instead we take the already parsed variant and generate a deserializer for that which
548-
// will feed the u8 wire format into the deserialization logic, which overrides the default
549-
// deserialize call on the the `Field` type.
550-
seed.deserialize(self.variant.into_deserializer())
551-
.map(|v| (v, self.de))
554+
// When serializing/deserializing, serde passes a variant_index into the handlers. We
555+
// currently write these as u8's and have already parsed' them during deserialize_enum
556+
// before we reach this point.
557+
//
558+
// Normally, when deserializing enum tags from a wire format that does not match the
559+
// expected size, we would use a u*.into_deserializer() to feed the already parsed
560+
// result into the visit_u64 visitor method during `__Field` deserialize.
561+
//
562+
// The problem with this however is during `visit_u64`, there is a possibility the
563+
// enum variant is not valid, which triggers Serde to return an `Unexpected` error.
564+
// These errors have the unfortunate side effect of triggering Rust including float
565+
// operations in the resulting binary, which breaks WASM environments.
566+
//
567+
// To work around this, we rely on the following facts:
568+
//
569+
// - variant_index in Serde is always 0 indexed and contiguous
570+
// - transmute_copy into a 0 sized type is safe
571+
// - transmute_copy is safe to cast into __Field as long as u8 >= size_of::<__Field>()
572+
//
573+
// This behaviour relies on serde not changing its enum deserializer generation, but
574+
// this would be a major backwards compatibility break for them so we should be safe.
575+
require!(
576+
size_of::<u8>() >= size_of::<V::Value>(),
577+
DeserializerError::InvalidEnumVariant
578+
);
579+
580+
Ok((
581+
unsafe { std::mem::transmute_copy::<u8, V::Value>(&self.variant) },
582+
self.de,
583+
))
552584
}
553585
}

0 commit comments

Comments
 (0)