@@ -368,6 +368,126 @@ mod tests {
368
368
) ;
369
369
}
370
370
371
+ #[ test]
372
+ #[ rustfmt:: skip]
373
+ /// This method tests that our EnumAccess workaround does not violate any memory safety rules.
374
+ /// In particular we want to make sure we avoid transmuting to any type that is not a u8 (we do
375
+ /// not support > 255 variants anyway).
376
+ fn test_serde_enum_access_behaviour ( ) {
377
+ use serde:: Deserialize ;
378
+ use serde:: Serialize ;
379
+
380
+ // Small-sized enums should all deserialize safely as single u8.
381
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
382
+ enum Singleton { A }
383
+
384
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
385
+ enum Pair { A , B }
386
+
387
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
388
+ enum Triple { A , B , C }
389
+
390
+ // Intentionally numbered enums with primitive representation (as long as u8) are safe.
391
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
392
+ enum CustomIndices {
393
+ A = 33 ,
394
+ B = 55 ,
395
+ C = 255 ,
396
+ }
397
+
398
+ // Complex enum's should still serialize as u8, and we expect the serde EnumAccess to work
399
+ // the same.
400
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
401
+ enum Complex {
402
+ A ,
403
+ B ( u8 , u8 ) ,
404
+ C { a : u8 , b : u8 } ,
405
+ }
406
+
407
+ // Forces the compiler to use a 16-bit discriminant. This must force the serde EnumAccess
408
+ // implementation to return an error. Otherwise we run the risk of the __Field enum in our
409
+ // transmute workaround becoming trash memory leading to UB.
410
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
411
+ enum ManyVariants {
412
+ _000, _001, _002, _003, _004, _005, _006, _007, _008, _009, _00A, _00B, _00C, _00D,
413
+ _00E, _00F, _010, _011, _012, _013, _014, _015, _016, _017, _018, _019, _01A, _01B,
414
+ _01C, _01D, _01E, _01F, _020, _021, _022, _023, _024, _025, _026, _027, _028, _029,
415
+ _02A, _02B, _02C, _02D, _02E, _02F, _030, _031, _032, _033, _034, _035, _036, _037,
416
+ _038, _039, _03A, _03B, _03C, _03D, _03E, _03F, _040, _041, _042, _043, _044, _045,
417
+ _046, _047, _048, _049, _04A, _04B, _04C, _04D, _04E, _04F, _050, _051, _052, _053,
418
+ _054, _055, _056, _057, _058, _059, _05A, _05B, _05C, _05D, _05E, _05F, _060, _061,
419
+ _062, _063, _064, _065, _066, _067, _068, _069, _06A, _06B, _06C, _06D, _06E, _06F,
420
+ _070, _071, _072, _073, _074, _075, _076, _077, _078, _079, _07A, _07B, _07C, _07D,
421
+ _07E, _07F, _080, _081, _082, _083, _084, _085, _086, _087, _088, _089, _08A, _08B,
422
+ _08C, _08D, _08E, _08F, _090, _091, _092, _093, _094, _095, _096, _097, _098, _099,
423
+ _09A, _09B, _09C, _09D, _09E, _09F, _0A0, _0A1, _0A2, _0A3, _0A4, _0A5, _0A6, _0A7,
424
+ _0A8, _0A9, _0AA, _0AB, _0AC, _0AD, _0AE, _0AF, _0B0, _0B1, _0B2, _0B3, _0B4, _0B5,
425
+ _0B6, _0B7, _0B8, _0B9, _0BA, _0BB, _0BC, _0BD, _0BE, _0BF, _0C0, _0C1, _0C2, _0C3,
426
+ _0C4, _0C5, _0C6, _0C7, _0C8, _0C9, _0CA, _0CB, _0CC, _0CD, _0CE, _0CF, _0D0, _0D1,
427
+ _0D2, _0D3, _0D4, _0D5, _0D6, _0D7, _0D8, _0D9, _0DA, _0DB, _0DC, _0DD, _0DE, _0DF,
428
+ _0E0, _0E1, _0E2, _0E3, _0E4, _0E5, _0E6, _0E7, _0E8, _0E9, _0EA, _0EB, _0EC, _0ED,
429
+ _0EE, _0EF, _0F0, _0F1, _0F2, _0F3, _0F4, _0F5, _0F6, _0F7, _0F8, _0F9, _0FA, _0FB,
430
+ _0FC, _0FD, _0FE, _0FF,
431
+
432
+ // > 255
433
+ _100
434
+ }
435
+
436
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
437
+ struct AllValid {
438
+ singleton : Singleton ,
439
+ pair : Pair ,
440
+ triple : Triple ,
441
+ complex : Complex ,
442
+ custom : CustomIndices ,
443
+ }
444
+
445
+ #[ derive( PartialEq , Serialize , Deserialize , Debug ) ]
446
+ struct Invalid {
447
+ many_variants : ManyVariants ,
448
+ }
449
+
450
+ let valid_buffer = [
451
+ // Singleton (A)
452
+ 0 ,
453
+ // Pair (B)
454
+ 1 ,
455
+ // Triple (C)
456
+ 2 ,
457
+ // Complex
458
+ 1 , 0 , 0 ,
459
+ // Custom
460
+ 2 ,
461
+ ] ;
462
+
463
+ let valid_struct = AllValid {
464
+ singleton : Singleton :: A ,
465
+ pair : Pair :: B ,
466
+ triple : Triple :: C ,
467
+ complex : Complex :: B ( 0 , 0 ) ,
468
+ custom : CustomIndices :: C ,
469
+ } ;
470
+
471
+ let valid_serialized = crate :: wire:: ser:: to_vec :: < _ , byteorder:: BE > ( & valid_struct) . unwrap ( ) ;
472
+
473
+ // Confirm that the valid buffer can be deserialized.
474
+ let valid = crate :: wire:: from_slice :: < byteorder:: BE , AllValid > ( & valid_buffer) . unwrap ( ) ;
475
+ let valid_deserialized = crate :: wire:: from_slice :: < byteorder:: BE , AllValid > ( & valid_serialized) . unwrap ( ) ;
476
+ assert_eq ! ( valid, valid_struct) ;
477
+ assert_eq ! ( valid_deserialized, valid_struct) ;
478
+
479
+ // Invalid buffer tests that types > u8 fail to deserialize, it's important to note that
480
+ // there is nothing stopping someone compiling a program with an invalid enum deserialize
481
+ // but we can at least ensure an error in deserialization occurs.
482
+ let invalid_buffer = [
483
+ // ManyVariants (256)
484
+ 1 , 0
485
+ ] ;
486
+
487
+ let result = crate :: wire:: from_slice :: < byteorder:: BE , Invalid > ( & invalid_buffer) ;
488
+ assert ! ( result. is_err( ) ) ;
489
+ }
490
+
371
491
// Test if the AccumulatorUpdateData type can be serialized and deserialized
372
492
// and still be the same as the original.
373
493
#[ test]
0 commit comments