@@ -23,6 +23,7 @@ class ModExpInput(TestParameterGroup):
23
23
exponent : Bytes
24
24
modulus : Bytes
25
25
extra_data : Bytes = Field (default_factory = Bytes )
26
+ raw_input : Bytes | None = None
26
27
27
28
@property
28
29
def length_base (self ) -> Bytes :
@@ -41,6 +42,8 @@ def length_modulus(self) -> Bytes:
41
42
42
43
def __bytes__ (self ):
43
44
"""Generate input for the MODEXP precompile."""
45
+ if self .raw_input is not None :
46
+ return self .raw_input
44
47
return (
45
48
self .length_base
46
49
+ self .length_exponent
@@ -60,20 +63,28 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput":
60
63
"""
61
64
if isinstance (input_data , str ):
62
65
input_data = Bytes (input_data )
63
- base_length = int .from_bytes (input_data [0 :32 ], byteorder = "big" )
64
- exponent_length = int .from_bytes (input_data [32 :64 ], byteorder = "big" )
65
- modulus_length = int .from_bytes (input_data [64 :96 ], byteorder = "big" )
66
+ assert not isinstance (input_data , str )
67
+ padded_input_data = input_data
68
+ if len (padded_input_data ) < 96 :
69
+ padded_input_data = Bytes (padded_input_data .ljust (96 , b"\0 " ))
70
+ base_length = int .from_bytes (padded_input_data [0 :32 ], byteorder = "big" )
71
+ exponent_length = int .from_bytes (padded_input_data [32 :64 ], byteorder = "big" )
72
+ modulus_length = int .from_bytes (padded_input_data [64 :96 ], byteorder = "big" )
73
+
74
+ total_required_length = 96 + base_length + exponent_length + modulus_length
75
+ if len (padded_input_data ) < total_required_length :
76
+ padded_input_data = Bytes (padded_input_data .ljust (total_required_length , b"\0 " ))
66
77
67
78
current_index = 96
68
- base = input_data [current_index : current_index + base_length ]
79
+ base = padded_input_data [current_index : current_index + base_length ]
69
80
current_index += base_length
70
81
71
- exponent = input_data [current_index : current_index + exponent_length ]
82
+ exponent = padded_input_data [current_index : current_index + exponent_length ]
72
83
current_index += exponent_length
73
84
74
- modulus = input_data [current_index : current_index + modulus_length ]
85
+ modulus = padded_input_data [current_index : current_index + modulus_length ]
75
86
76
- return cls (base = base , exponent = exponent , modulus = modulus )
87
+ return cls (base = base , exponent = exponent , modulus = modulus , raw_input = input_data )
77
88
78
89
79
90
class ModExpOutput (TestParameterGroup ):
0 commit comments