Skip to content

Commit c7e29fc

Browse files
authored
Add Moore-Penrose pseudo inverse and fast non-negative least squares (#455)
* Add Moore-Penrose pseudo inverse and fast non-negative least squares. As we look into Kubelka-Munk, it became apparent that we need some ways to solve least squares problems. The Moore-Penrose implementation currently just does a basic inverse of tall and wide matrices, but really tricky matrices will likely not invert and would require a much more advanced approach. Hopefully we won't need such an approach. `pinv` can produce negative values, so if you need a non-negative least squares solution, `fnnls` will do the trick, though it is more costly.
1 parent d46e67a commit c7e29fc

File tree

4 files changed

+191
-10
lines changed

4 files changed

+191
-10
lines changed

coloraide/algebra.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def clamp(
159159
return value
160160

161161

162-
def zdiv(a: float, b: float) -> float:
162+
def zdiv(a: float, b: float, default: float = 0.0) -> float:
    """
    Divide `a` by `b` while guarding against a zero divisor.

    When `b` compares equal to zero, `default` (0.0 unless specified) is
    returned instead of performing the division.
    """

    return default if b == 0 else a / b
168168

169169

@@ -3592,6 +3592,29 @@ def inv(matrix: MatrixLike | TensorLike) -> Matrix | Tensor:
35923592
return _back_sub_matrix(u, _forward_sub_matrix(l, p, s2), s2)
35933593

35943594

3595+
def pinv(a: MatrixLike) -> Matrix:
    """
    Compute the (Moore-Penrose) pseudo-inverse of a matrix.

    The matrix is currently 'assumed' to be full rank. If it is not, a singular
    matrix error will be thrown by the underlying inverse. Such matrices may still
    be invertible, but they would require a more advanced approach that is not
    currently implemented.

    Tall and square matrices use `inv(A^T A) A^T`, wide matrices use `A^T inv(A A^T)`.

    Negative results can be returned, use `fnnls` for a non-negative solution (if possible).
    """

    dims = shape(a)
    if len(dims) != 2:
        raise ValueError('Inputs can only be matrices, vectors or tensors are not allowed')

    at = transpose(a)
    rows, cols = dims[0], dims[1]

    # Wide matrix: more columns than rows.
    if rows < cols:
        return matmul(at, inv(matmul(a, at, dims=D2)), dims=D2)

    # Tall (or square) matrix: at least as many rows as columns.
    return matmul(inv(matmul(at, a, dims=D2)), at, dims=D2)
3616+
3617+
35953618
@overload
35963619
def vstack(arrays: Sequence[float | Vector | Matrix]) -> Matrix:
35973620
...
@@ -3866,3 +3889,88 @@ def inner(a: float | ArrayLike, b: float | ArrayLike) -> float | Array:
38663889

38673890
# Shape the data.
38683891
return reshape(m, new_shape) # type: ignore[no-any-return]
3892+
3893+
3894+
def fnnls(
    A: MatrixLike,
    b: VectorLike,
    epsilon: float = 1e-12,
    max_iters: int = 0
) -> tuple[Vector, float]:
    """
    Fast non-negative least squares.

    Searches for a vector `x` with no negative entries that minimizes `||b - Ax||`.

    - `A`: the coefficient matrix.
    - `b`: the target vector.
    - `epsilon`: tolerance; values at or below it are treated as non-positive.
    - `max_iters`: limit on the number of inner correction steps, checked between
      outer iterations. When `0` (the default), `30 * n` is used, where `n` is the
      number of columns in `A`.

    Returns a tuple `(x, res)` where `res` is the residual `||b - Ax||`.

    A fast non-negativity-constrained least squares
    https://www.researchgate.net/publication/230554373_A_Fast_Non-negativity-constrained_Least_Squares_Algorithm
    Rasmus Bro and Sijmen De Jong
    Journal of Chemometrics. 11, 393–401 (1997)
    """

    n = len(A[0])

    if not max_iters:
        max_iters = n * 30

    # Only the cross products are needed while iterating
    AT = transpose(A)
    ATA = dot(AT, A, dims=D2)
    ATb = dot(AT, b, dims=D2_D1)

    x = zeros(n)  # type: Vector # type: ignore[assignment]
    s = zeros(n)  # type: Vector # type: ignore[assignment]
    w = subtract(ATb, dot(ATA, x, dims=D2_D1), dims=D1)  # type: Vector

    # P tracks positive elements in x
    P = [False] * n  # type: VectorBool

    # Continue until all values of x are positive (non-negative results only)
    # or we exhaust the iterations.
    count = 0
    while sum(P) < n and max(w[_i] for _i in range(n) if not P[_i]) > epsilon and count < max_iters:
        # Find the index that maximizes w
        # This will be an index not in P
        imx = 0
        mx = float('-inf')
        for _i in range(n):
            if not P[_i]:
                temp = w[_i]
                if temp > mx:
                    imx = _i
                    mx = temp
        P[imx] = True

        # Solve least squares problem for the columns and rows that are in P
        s = _fnnls_solve(ATA, ATb, P)

        # Deal with negative values
        while _any([s[_i] <= epsilon for _i in range(n) if P[_i]]):
            count += 1

            # Calculate step size, alpha, to prevent any x from going negative.
            # Only indexes in P whose candidate value is non-positive may limit the step.
            # `P[_i] and ...` is required here: `P[_i] * s[_i] <= epsilon` would compare the
            # product and let every index outside of P through.
            alpha = min(
                [zdiv(x[_i], (x[_i] - s[_i]), float('inf')) for _i in range(n) if P[_i] and s[_i] <= epsilon]
            )

            # Update the solution
            x = add(x, dot(alpha, subtract(s, x, dims=D1), dims=SC_D1), dims=D1)

            # Remove indexes in P where x == 0
            for _i in range(n):
                if x[_i] <= epsilon:
                    P[_i] = False

            # Solve least squares problem again
            s = _fnnls_solve(ATA, ATb, P)

        # Update the solution
        x = s[:]
        w = subtract(ATb, dot(ATA, x, dims=D2_D1), dims=D1)

    # Return our final result, for better or for worse
    res = math.hypot(*subtract(b, dot(A, x, dims=D2_D1), dims=D1))  # ||b-Ax||
    return x, res


def _fnnls_solve(ATA: Matrix, ATb: Vector, P: VectorBool) -> Vector:
    """Solve the unconstrained least squares problem restricted to the indexes flagged in `P`."""

    n = len(P)
    s = [0.0] * n  # type: Vector
    idx = [_i for _i in range(n) if P[_i]]

    # With nothing flagged there is no sub-matrix to invert, every entry stays zero
    if idx:
        v = dot(inv([[ATA[_i][_j] for _j in idx] for _i in idx]), [ATb[_i] for _i in idx], dims=D2_D1)
        for _i, _v in zip(idx, v):
            s[_i] = _v
    return s

docs/src/dictionary/en-custom.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ CVD
3737
CVDs
3838
Catmull
3939
Changelog
40+
Chemometrics
4041
Chroma
4142
Chromaticities
4243
Chromaticity
@@ -51,6 +52,7 @@ Cubehelix
5152
Culori
5253
Cz
5354
DCI
55+
De
5456
Deprecations
5557
Deregister
5658
Deregistration
@@ -95,6 +97,7 @@ IPT
9597
ITP
9698
ITU
9799
IgPgTg
100+
Illuminant
98101
Interpolator
99102
Itten
100103
Iz
@@ -105,6 +108,7 @@ JMh
105108
JND
106109
JSON
107110
Jacobian
111+
Jong
108112
Jsh
109113
Judd
110114
Jupyter
@@ -113,10 +117,13 @@ JzCzhz
113117
JzMzhz
114118
Jzazbz
115119
Kries
120+
Kubelka
116121
Kz
117122
LCh
118123
LChish
119124
LChuv
125+
LHTSS
126+
LLSS
120127
LMS
121128
Lab
122129
Labish
@@ -135,6 +142,7 @@ MkDocs
135142
Mollon
136143
Monochromacy
137144
Moroney
145+
Munk
138146
Mz
139147
NONINFRINGEMENT
140148
NaN
@@ -153,6 +161,7 @@ Oklrab
153161
Ostrowski
154162
Ottosson
155163
PQ
164+
Penrose
156165
Perceptibility
157166
Piecewise
158167
Planckian
@@ -177,6 +186,7 @@ RLAB
177186
ROMM
178187
RYB
179188
Raphson
189+
Rasmus
180190
SCD
181191
SDR
182192
SL
@@ -186,6 +196,7 @@ SVG
186196
Safdar
187197
Scalable
188198
Sharma
199+
Sijmen
189200
Sz
190201
TODO
191202
TORTIOUS
@@ -210,6 +221,7 @@ Vos
210221
Vz
211222
WCAG
212223
WCG
224+
Wijnen
213225
Wz
214226
XD
215227
XYB
@@ -256,16 +268,19 @@ desaturated
256268
deuteranomaly
257269
deuteranopia
258270
dichromacy
271+
differencing
259272
diffuser
260273
discretized
261274
docstring
262275
dyad
263276
easings
277+
emissive
264278
fixup
265279
formatter
266280
grayscale
267281
helixes
268282
hz
283+
illum
269284
illuminance
270285
illuminant
271286
illuminants
@@ -290,6 +305,9 @@ monochromacy
290305
monotonicity
291306
nd
292307
nm
308+
normalizations
309+
nx
310+
nxn
293311
oRGB
294312
opRGB
295313
opto
@@ -316,6 +334,7 @@ quantized
316334
quantizer
317335
rc
318336
reflectance
337+
reflectances
319338
repurpose
320339
rgb
321340
sRGB

pyproject.toml

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,24 @@ lint.select = [
105105
]
106106

107107
lint.ignore = [
108-
"E741",
108+
"B905",
109109
"D202",
110-
"D401",
111-
"D212",
112110
"D203",
113-
"N802",
111+
"D212",
112+
"D401",
113+
"E741",
114114
"N801",
115+
"N802",
115116
"N803",
116117
"N806",
117118
"N818",
118-
"RUF012",
119-
"RUF005",
120119
"PGH004",
120+
"RUF002",
121+
"RUF005",
122+
"RUF012",
121123
"RUF022",
122124
"RUF023",
123-
"RUF100",
124-
"B905"
125+
"RUF100"
125126
]
126127

127128
[tool.coverage.report]

tests/test_algebra.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2613,6 +2613,59 @@ def dx2(x):
26132613
# Ostrowski
26142614
self.assertEqual(alg.solve_newton(1, f0, dx, ostrowski=True), (0.5, True))
26152615

2616+
def test_pinv(self):
    """Test Moore-Penrose pseudo inverse with square, tall, and wide matrices."""

    m = [
        [0.4123907992659593, 0.3575843393838777, 0.1804807884018343],
        [0.21263900587151033, 0.7151686787677553, 0.07219231536073373],
        [0.019330818715591832, 0.11919477979462595, 0.9505321522496605]
    ]

    v = [0.047770200571454854, 0.02780940276126581, 0.22476064520055364]

    # Negative results can be returned
    result = alg.dot(alg.pinv(m), v)
    self.assertEqual(result, [-5.551115123125783e-17, 0.015208514422912689, 0.23455058216100527])

    # Tall matrix: more rows (3) than columns (2)
    tall = alg.pinv([[4, 5], [3, 3], [9, 7]])
    self.assertEqual(alg.dot(tall, [3, 5, 6]), [0.29640718562873936, 0.538922155688625])

    # Wide matrix: more columns (3) than rows (2)
    wide = alg.pinv([[4, 5, 3], [9, 7, 3]])
    self.assertEqual(alg.dot(wide, [3, 5]), [0.2872727272727278, 0.2818181818181821, 0.14727272727272767])

    # Only matrices are accepted, a vector must be rejected
    with self.assertRaises(ValueError):
        alg.pinv([1, 2, 3])
2639+
2640+
def test_fnnls(self):
    """Verify the fast non-negative least squares solver against solvable and unsolvable targets."""

    matrix = [
        [0.4123907992659593, 0.3575843393838777, 0.1804807884018343],
        [0.21263900587151033, 0.7151686787677553, 0.07219231536073373],
        [0.019330818715591832, 0.11919477979462595, 0.9505321522496605]
    ]

    target = [0.047770200571454854, 0.02780940276126581, 0.22476064520055364]

    solution, residual = alg.fnnls(matrix, target)
    reference = alg.dot(alg.pinv(matrix), target)

    # Nothing may be negative, yet the answer should track the `pinv` approach closely.
    self.assertTrue(all(_a >= 0 for _a in solution))
    self.assertTrue(residual < 1e-10)
    self.assertTrue(
        all(math.isclose(_a, _b, rel_tol=1e-10, abs_tol=1e-11) for _a, _b in zip(solution, reference))
    )

    # This target is purposely beyond the range of a reasonable solution,
    # so some residual is expected.
    target = [0.6369580483012911, 0.262700212011267, 4.994106574466074e-17]
    solution, residual = alg.fnnls(matrix, target)

    # Still nothing negative, but the residual is no longer negligible.
    self.assertFalse(residual < 1e-10)
    self.assertTrue(all(_a >= 0 for _a in solution))
    self.assertEqual(solution, [1.477061311287275, 0.0, 0.0])
2668+
26162669

26172670
def test_pprint(capsys):
26182671
"""Test matrix print."""

0 commit comments

Comments
 (0)