Skip to content

Commit 379357e

Browse files
committed
Vectorize scaling norm computation
1 parent 07dcad4 commit 379357e

File tree

1 file changed

+17
-28
lines changed

1 file changed

+17
-28
lines changed

linopy/scaling.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,15 @@ def _row_norms(A: csc_matrix, method: ScaleMethod) -> ndarray:
8484
Compute per-row magnitudes for a sparse matrix.
8585
"""
8686
A_csr = A.tocsr()
87-
indptr = A_csr.indptr
88-
data = np.abs(A_csr.data)
89-
n_rows = A_csr.shape[0]
90-
norms = np.ones(n_rows, dtype=float)
91-
92-
for i in range(n_rows):
93-
start, end = indptr[i], indptr[i + 1]
94-
row_data = data[start:end]
95-
if row_data.size == 0:
96-
norms[i] = 1.0
97-
elif method == "row-l2":
98-
norms[i] = np.sqrt(np.mean(row_data**2))
99-
else:
100-
norms[i] = row_data.max()
87+
if method == "row-l2":
88+
norms = np.sqrt(np.array(A_csr.power(2).mean(axis=1)).ravel(), dtype=float)
89+
else:
90+
A_abs = A_csr.copy()
91+
A_abs.data = np.abs(A_abs.data)
92+
norms = np.array(A_abs.max(axis=1).toarray()).ravel().astype(float)
93+
94+
# rows without entries yield 0 or nan; keep them unscaled
95+
norms = np.where(np.isnan(norms) | (norms == 0), 1.0, norms)
10196
return norms
10297

10398

@@ -106,20 +101,14 @@ def _col_norms(A: csc_matrix, method: ScaleMethod) -> ndarray:
106101
Compute per-column magnitudes for a sparse matrix.
107102
"""
108103
A_csc = A.tocsc()
109-
indptr = A_csc.indptr
110-
data = np.abs(A_csc.data)
111-
n_cols = A_csc.shape[1]
112-
norms = np.ones(n_cols, dtype=float)
113-
114-
for j in range(n_cols):
115-
start, end = indptr[j], indptr[j + 1]
116-
col_data = data[start:end]
117-
if col_data.size == 0:
118-
norms[j] = 1.0
119-
elif method == "row-l2":
120-
norms[j] = np.sqrt(np.mean(col_data**2))
121-
else:
122-
norms[j] = col_data.max()
104+
if method == "row-l2":
105+
norms = np.sqrt(np.array(A_csc.power(2).mean(axis=0)).ravel(), dtype=float)
106+
else:
107+
A_abs = A_csc.copy()
108+
A_abs.data = np.abs(A_abs.data)
109+
norms = np.array(A_abs.max(axis=0).toarray()).ravel().astype(float)
110+
111+
norms = np.where(np.isnan(norms) | (norms == 0), 1.0, norms)
123112
return norms
124113

125114

0 commit comments

Comments
 (0)