@@ -83,35 +83,60 @@ def make_matrix(msgs, sigs, pubs, B, order, matrix_type="dense"):
8383 return matrix
8484
8585
86- def privkeys_from_reduced_matrix (msgs , sigs , pubs , matrix , order ):
87- keys = []
86+ def privkeys_from_reduced_matrix (msgs , sigs , pubs , matrix , order , max_rows = 20 ):
87+ """
88+ Extract private keys by:
89+ • Precomputing (a,b,cd,ab_list) for all msgs,
90+ • Sorting rows by ||row|| ascending,
91+ • Testing only the top `max_rows` rows.
92+ """
93+ from math import sqrt
94+ keys = set ()
95+ m = len (msgs )
8896 msgn , rn , sn = msgs [- 1 ], sigs [- 1 ][0 ], sigs [- 1 ][1 ]
8997
90- for i in range (len (msgs )):
98+ # 1) Precompute per-i constants
99+ params = []
100+ for i in range (m ):
91101 a = rn * sigs [i ][1 ]
92102 b = sn * sigs [i ][0 ]
93103 c = sn * msgs [i ]
94104 d = msgn * sigs [i ][1 ]
95- cd = c - d
96-
97- if a == b :
98- for row in matrix :
99- for j in range (len (msgs )):
100- potential_nonce_diff = row [j ]
101- key = (cd - (b * potential_nonce_diff )) % order
102- if key not in keys :
103- keys .append (key )
104- else :
105- for row in matrix :
106- for j in range (len (msgs )):
107- potential_nonce_diff = row [j ]
108- potential_priv_key = cd - (b * potential_nonce_diff )
109- for ab in [a - b , b - a ]:
105+ cd = (c - d ) % order
106+ if a == b : ab_list = None
107+ else : ab_list = [ (a - b ) % order , (b - a ) % order ]
108+ params .append ((b , cd , ab_list ))
109+
110+ # 2) Compute row norms once
111+ row_norms = []
112+ for idx , row in enumerate (matrix ):
113+ # only consider first m components for the norm
114+ norm2 = sum ((float (row [j ])** 2 for j in range (m )))
115+ row_norms .append ((norm2 , idx ))
116+ row_norms .sort ()
117+
118+ # 3) Only test top max_rows shortest rows
119+ for _ , ridx in row_norms [:max_rows ]:
120+ row = matrix [ridx ]
121+ # extract all potential k-diffs at once
122+ kdiffs = [int (row [j ]) for j in range (m )]
123+ # for each message i, attempt recovery
124+ for i , (b , cd , ab_list ) in enumerate (params ):
125+ base = (cd - b * kdiffs [i ])
126+ if ab_list is None :
127+ # special case a==b -> key = base
128+ if 0 < base < order : keys .add (base )
129+ else : keys .add (base % order )
130+ else :
131+ for ab in ab_list :
132+ # modular_inv only if ab != 0
133+ if ab :
110134 inv = modular_inv (ab , order )
111- key = (potential_priv_key * inv ) % order
112- if key not in keys :
113- keys .append (key )
114- return keys
135+ key = (base * inv )
136+ if 0 < key < order : keys .add (key )
137+ else : keys .add (key % order )
138+ return list (keys )
139+
115140
116141
117142def display_keys (keys ):
0 commit comments