@@ -84,20 +84,15 @@ def _row_norms(A: csc_matrix, method: ScaleMethod) -> ndarray:
8484 Compute per-row magnitudes for a sparse matrix.
8585 """
8686 A_csr = A .tocsr ()
87- indptr = A_csr .indptr
88- data = np .abs (A_csr .data )
89- n_rows = A_csr .shape [0 ]
90- norms = np .ones (n_rows , dtype = float )
91-
92- for i in range (n_rows ):
93- start , end = indptr [i ], indptr [i + 1 ]
94- row_data = data [start :end ]
95- if row_data .size == 0 :
96- norms [i ] = 1.0
97- elif method == "row-l2" :
98- norms [i ] = np .sqrt (np .mean (row_data ** 2 ))
99- else :
100- norms [i ] = row_data .max ()
87+ if method == "row-l2" :
88+ norms = np .sqrt (np .array (A_csr .power (2 ).mean (axis = 1 )).ravel (), dtype = float )
89+ else :
90+ A_abs = A_csr .copy ()
91+ A_abs .data = np .abs (A_abs .data )
92+ norms = np .array (A_abs .max (axis = 1 ).toarray ()).ravel ().astype (float )
93+
94+ # rows without entries yield 0 or nan; keep them unscaled
95+ norms = np .where (np .isnan (norms ) | (norms == 0 ), 1.0 , norms )
10196 return norms
10297
10398
@@ -106,20 +101,14 @@ def _col_norms(A: csc_matrix, method: ScaleMethod) -> ndarray:
106101 Compute per-column magnitudes for a sparse matrix.
107102 """
108103 A_csc = A .tocsc ()
109- indptr = A_csc .indptr
110- data = np .abs (A_csc .data )
111- n_cols = A_csc .shape [1 ]
112- norms = np .ones (n_cols , dtype = float )
113-
114- for j in range (n_cols ):
115- start , end = indptr [j ], indptr [j + 1 ]
116- col_data = data [start :end ]
117- if col_data .size == 0 :
118- norms [j ] = 1.0
119- elif method == "row-l2" :
120- norms [j ] = np .sqrt (np .mean (col_data ** 2 ))
121- else :
122- norms [j ] = col_data .max ()
104+ if method == "row-l2" :
105+ norms = np .sqrt (np .array (A_csc .power (2 ).mean (axis = 0 )).ravel (), dtype = float )
106+ else :
107+ A_abs = A_csc .copy ()
108+ A_abs .data = np .abs (A_abs .data )
109+ norms = np .array (A_abs .max (axis = 0 ).toarray ()).ravel ().astype (float )
110+
111+ norms = np .where (np .isnan (norms ) | (norms == 0 ), 1.0 , norms )
123112 return norms
124113
125114
0 commit comments