@@ -83,35 +83,62 @@ def make_matrix(msgs, sigs, pubs, B, order, matrix_type="dense"):
8383 return matrix
8484
8585
86- def privkeys_from_reduced_matrix (msgs , sigs , pubs , matrix , order ):
87- keys = []
86+ def privkeys_from_reduced_matrix (msgs , sigs , pubs , matrix , order , max_rows = 20 ):
87+ """
88+ Extract private keys by:
89+ • Precomputing (a,b,cd,ab_list) for all msgs,
90+ • Sorting rows by ||row|| ascending,
91+ • Testing only the top `max_rows` rows.
92+ """
93+ from math import sqrt
94+ keys = set ()
95+ m = len (msgs )
8896 msgn , rn , sn = msgs [- 1 ], sigs [- 1 ][0 ], sigs [- 1 ][1 ]
8997
90- for i in range (len (msgs )):
98+ # 1) Precompute per-i constants
99+ params = []
100+ for i in range (m ):
91101 a = rn * sigs [i ][1 ]
92102 b = sn * sigs [i ][0 ]
93103 c = sn * msgs [i ]
94104 d = msgn * sigs [i ][1 ]
95- cd = c - d
96-
105+ cd = (c - d ) % order
97106 if a == b :
98- for row in matrix :
99- for j in range (len (msgs )):
100- potential_nonce_diff = row [j ]
101- key = (cd - (b * potential_nonce_diff )) % order
102- if key not in keys :
103- keys .append (key )
107+ ab_list = None
104108 else :
105- for row in matrix :
106- for j in range (len (msgs )):
107- potential_nonce_diff = row [j ]
108- potential_priv_key = cd - (b * potential_nonce_diff )
109- for ab in [a - b , b - a ]:
109+ ab_list = [ (a - b ) % order , (b - a ) % order ]
110+ params .append ((b , cd , ab_list ))
111+
112+ # 2) Compute row norms once
113+ row_norms = []
114+ for idx , row in enumerate (matrix ):
115+ # only consider first m components for the norm
116+ norm2 = sum ((float (row [j ])** 2 for j in range (m )))
117+ row_norms .append ((norm2 , idx ))
118+ row_norms .sort ()
119+
120+ # 3) Only test top max_rows shortest rows
121+ for _ , ridx in row_norms [:max_rows ]:
122+ row = matrix [ridx ]
123+ # extract all potential k-diffs at once
124+ kdiffs = [int (row [j ]) for j in range (m )]
125+ # for each message i, attempt recovery
126+ for i , (b , cd , ab_list ) in enumerate (params ):
127+ base = (cd - b * kdiffs [i ])
128+ if ab_list is None :
129+ # special case a==b -> key = base
130+ if 0 < base < order : keys .add (base )
131+ else : keys .add (base % order )
132+ else :
133+ for ab in ab_list :
134+ # modular_inv only if ab != 0
135+ if ab :
110136 inv = modular_inv (ab , order )
111- key = (potential_priv_key * inv ) % order
112- if key not in keys :
113- keys .append (key )
114- return keys
137+ key = (base * inv )
138+ if 0 < key < order : keys .add (key )
139+ else : keys .add (key % order )
140+ return list (keys )
141+
115142
116143
117144def display_keys (keys ):
0 commit comments