@@ -23,6 +23,7 @@ class ModExpInput(TestParameterGroup):
23
23
exponent : Bytes
24
24
modulus : Bytes
25
25
extra_data : Bytes = Field (default_factory = Bytes )
26
+ raw_input : Bytes | None = None
26
27
27
28
@property
28
29
def length_base (self ) -> Bytes :
@@ -41,6 +42,8 @@ def length_modulus(self) -> Bytes:
41
42
42
43
def __bytes__ (self ):
43
44
"""Generate input for the MODEXP precompile."""
45
+ if self .raw_input is not None :
46
+ return self .raw_input
44
47
return (
45
48
self .length_base
46
49
+ self .length_exponent
@@ -60,22 +63,31 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput":
60
63
"""
61
64
if isinstance (input_data , str ):
62
65
input_data = Bytes (input_data )
63
- base_length = int .from_bytes (input_data [0 :32 ], byteorder = "big" )
64
- exponent_length = int .from_bytes (input_data [32 :64 ], byteorder = "big" )
65
- modulus_length = int .from_bytes (input_data [64 :96 ], byteorder = "big" )
66
+ assert not isinstance (input_data , str )
67
+ padded_input_data = input_data
68
+ if len (padded_input_data ) < 96 :
69
+ padded_input_data = Bytes (padded_input_data .ljust (96 , b"\0 " ))
70
+
71
+ base_length = int .from_bytes (padded_input_data [0 :32 ], byteorder = "big" )
72
+ exponent_length = int .from_bytes (padded_input_data [32 :64 ], byteorder = "big" )
73
+ modulus_length = int .from_bytes (padded_input_data [64 :96 ], byteorder = "big" )
74
+
75
+ total_required_length = 96 + base_length + exponent_length + modulus_length
76
+ if len (padded_input_data ) < total_required_length :
77
+ padded_input_data = Bytes (
78
+ padded_input_data .ljust (min (1024 , total_required_length ), b"\0 " )
79
+ )
66
80
67
81
current_index = 96
68
- base = input_data [current_index : current_index + base_length ]
82
+ base = padded_input_data [current_index : current_index + base_length ]
69
83
current_index += base_length
70
84
71
- exponent = input_data [current_index : current_index + exponent_length ]
85
+ exponent = padded_input_data [current_index : current_index + exponent_length ]
72
86
current_index += exponent_length
73
87
74
- modulus = input_data [current_index : current_index + modulus_length ]
75
-
76
- modulus = modulus .ljust (min (1024 , modulus_length ), b"\x00 " )
88
+ modulus = padded_input_data [current_index : current_index + modulus_length ]
77
89
78
- return cls (base = base , exponent = exponent , modulus = modulus )
90
+ return cls (base = base , exponent = exponent , modulus = modulus , raw_input = input_data )
79
91
80
92
81
93
class ModExpOutput (TestParameterGroup ):
0 commit comments