Skip to content

Commit d5b40b3

Browse files
authored
QR decomposition (#577)
* QR decomposition
* Replace one map_direct with a merge chunks
* Use multiple outputs for QR
* Use recursion when R1 is too big to fit in memory
* _merge_into_single_chunk
* Add map_blocks_multiple_outputs utility function
* Add tsqr
* Add memory utilization test for qr
* Enforce tall-and-skinny for qr
* QR recursion improvements and test
1 parent a4020f3 commit d5b40b3

File tree

5 files changed

+260
-2
lines changed

5 files changed

+260
-2
lines changed

cubed/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,9 @@
315315
from .array_api.utility_functions import all, any
316316

317317
__all__ += ["all", "any"]
318+
319+
# extensions
320+
321+
from .array_api import linalg
322+
323+
__all__ += ["linalg"]

cubed/array_api/linalg.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from typing import NamedTuple
2+
3+
from cubed.array_api.array_object import Array
4+
from cubed.backend_array_api import namespace as nxp
5+
from cubed.core.ops import general_blockwise, map_direct, merge_chunks
6+
from cubed.utils import array_memory, get_item
7+
8+
9+
class QRResult(NamedTuple):
    """Pair of arrays produced by a QR factorization; unpacks as ``Q, R``."""

    # the "Q" factor
    Q: Array
    # the "R" factor
    R: Array
12+
13+
14+
def qr(x, /, *, mode="reduced") -> QRResult:
    """Compute the QR factorization of a tall-and-skinny 2-D array.

    The input is validated here and the actual work is delegated to
    ``tsqr``.

    Parameters
    ----------
    x
        Array to factorize. It must have exactly two dimensions and a
        single chunk along axis 1.
    mode
        Only ``"reduced"`` is accepted.

    Returns
    -------
    QRResult
        The ``(Q, R)`` pair returned by ``tsqr``.

    Raises
    ------
    ValueError
        If ``x`` is not 2-D, if ``mode`` is not ``"reduced"``, or if ``x``
        has more than one chunk along axis 1.
    """
    is_matrix = x.ndim == 2
    if not is_matrix:
        raise ValueError("qr requires x to have 2 dimensions.")

    reduced_requested = mode == "reduced"
    if not reduced_requested:
        raise ValueError("qr only supports mode='reduced'")

    # the algorithm operates on whole rows, so columns may not be split
    column_blocks = x.numblocks[1]
    if column_blocks > 1:
        raise ValueError(
            "qr only supports tall-and-skinny (single column chunk) arrays. "
            "Consider rechunking so there is only a single column chunk."
        )

    return tsqr(x)
28+
29+
30+
def tsqr(x) -> QRResult:
    """Direct Tall-and-Skinny QR algorithm

    From:

        Direct QR factorizations for tall-and-skinny matrices in MapReduce architectures
        Austin R. Benson, David F. Gleich, James Demmel
        Proceedings of the IEEE International Conference on Big Data, 2013
        https://arxiv.org/abs/1301.1071

    The implementation follows Algorithm 2 of the paper in three steps:
    a QR of each block of ``x``, a QR of the stacked R factors, and a
    block-wise product of the two Q factors. When the stacked R factors are
    judged too large by ``_r1_is_too_big``, their row chunks are merged and
    the second step is carried out by calling ``tsqr`` recursively.
    """
    # step 1: factorize every block independently
    q_first, r_stacked = _qr_first_step(x)

    # step 2: factorize the stacked R factors, directly or recursively
    if not _r1_is_too_big(r_stacked):
        q_second, r_final = _qr_second_step(r_stacked)
    else:
        q_second, r_final = tsqr(_rechunk_r1(r_stacked))

    # step 3: combine the two Q factors block by block
    q_final = _qr_third_step(q_first, q_second)

    return QRResult(q_final, r_final)
53+
54+
55+
def _qr_first_step(A):
    """Apply ``nxp.linalg.qr`` to each block of ``A`` (first TSQR step).

    Returns a ``QRResult`` ``(Q1, R1)``. ``Q1`` is declared with the same
    shape and chunks as ``A``. ``R1`` is declared as ``k`` stacked blocks of
    shape ``(n, n)``, where ``n`` is the column chunk size of ``A`` and ``k``
    is its number of row blocks.
    """
    _, col_size = A.chunksize
    row_blocks, _ = A.numblocks

    # NOTE(review): an (n, n) R block per input block presumably requires
    # every block of A to have at least n rows - confirm for ragged chunks.
    r1_shape = (col_size * row_blocks, col_size)
    r1_chunks = ((col_size,) * row_blocks, (col_size,))

    q1, r1 = map_blocks_multiple_outputs(
        nxp.linalg.qr,
        A,
        shapes=[A.shape, r1_shape],
        dtypes=[nxp.float64, nxp.float64],
        chunkss=[A.chunks, r1_chunks],
        # qr implementation creates internal array buffers
        extra_projected_mem=A.chunkmem * 4,
    )
    return QRResult(q1, r1)
73+
74+
75+
def _r1_is_too_big(R1):
    """Return True if the whole of ``R1`` exceeds a conservative memory budget.

    The budget is one eighth of the memory left after subtracting
    ``reserved_mem`` from ``allowed_mem`` on the array's spec.
    """
    spec = R1.spec
    usable_mem = spec.allowed_mem - spec.reserved_mem
    # conservative values for max_mem (4 copies, doubled to give some slack)
    budget = usable_mem // (4 * 2)
    return array_memory(R1.dtype, R1.shape) > budget
80+
81+
82+
def _rechunk_r1(R1, split_every=4):
    """Merge row chunks of ``R1`` so each chunk is ``split_every`` times taller.

    Factorizing the rechunked array block by block then yields a new R1 that
    is smaller by a factor of ``split_every``.

    Raises
    ------
    ValueError
        If ``R1`` already has a single chunk along axis 0, so its chunk size
        cannot be expanded any further.
    """
    if R1.numblocks[0] != 1:
        taller_rows = R1.chunksize[0] * split_every
        return merge_chunks(R1, chunks=(taller_rows, R1.chunksize[1]))

    raise ValueError(
        "Can't expand R1 chunk size further. Try increasing allowed_mem"
    )
90+
91+
92+
def _qr_second_step(R1):
    """Factorize the stacked R factors as one chunk (second TSQR step).

    ``R1`` is first merged into a single chunk, then ``nxp.linalg.qr`` is
    applied to it. Returns a ``QRResult`` ``(Q2, R2)`` in which ``Q2`` is
    declared with the shape of ``R1`` and ``R2`` with shape ``(n, n)``, where
    ``n`` is the number of columns of ``R1``. Both outputs are single-chunk.
    """
    merged = _merge_into_single_chunk(R1)

    ncols = R1.shape[1]
    q2_shape = R1.shape
    r2_shape = (ncols, ncols)

    q2, r2 = map_blocks_multiple_outputs(
        nxp.linalg.qr,
        merged,
        shapes=[q2_shape, r2_shape],
        dtypes=[nxp.float64, nxp.float64],
        # each output is a single chunk, so chunks equal the shape
        chunkss=[q2_shape, r2_shape],
        # qr implementation creates internal array buffers
        extra_projected_mem=merged.chunkmem * 4,
    )
    return QRResult(q2, r2)
112+
113+
114+
def _merge_into_single_chunk(x, split_every=4):
115+
# do a tree merge along first axis
116+
while x.numblocks[0] > 1:
117+
chunks = (x.chunksize[0] * split_every,) + x.chunksize[1:]
118+
x = merge_chunks(x, chunks)
119+
return x
120+
121+
122+
def _qr_third_step(Q1, Q2):
    """Combine the Q factors of the first two steps (third TSQR step).

    Builds an array with the shape and chunks of ``Q1`` whose blocks are
    computed by ``_q_matmul`` from a block of ``Q1`` and an ``(n, n)`` slice
    of ``Q2``, where ``n`` is the column chunk size of ``Q1``.
    """
    _, col_size = Q1.chunksize
    row_blocks, _ = Q1.numblocks

    # slicing scheme for Q2: one (n, n) slice per row block of Q1
    q2_slices = ((col_size,) * row_blocks, (col_size,))

    return map_direct(
        _q_matmul,
        Q1,
        Q2,
        shape=Q1.shape,
        dtype=nxp.float64,
        chunks=Q1.chunks,
        extra_projected_mem=0,
        q1_chunks=Q1.chunks,
        q2_chunks=q2_slices,
    )
143+
144+
145+
def _q_matmul(x, *arrays, q1_chunks=None, q2_chunks=None, block_id=None):
    """Return the matrix product of the ``block_id`` slices of two arrays.

    ``arrays[0]`` is sliced according to ``q1_chunks`` and ``arrays[1]``
    according to ``q2_chunks``; both are read through their ``zarray``
    attribute. ``x`` is not used here (presumably a placeholder block
    supplied by ``map_direct`` - confirm against its contract).
    """

    def read_block(source, chunks):
        return source.zarray[get_item(chunks, block_id)]

    left = read_block(arrays[0], q1_chunks)
    # the second array is not chunked like q2_chunks, so take the slice that
    # corresponds to this block rather than a whole chunk
    right = read_block(arrays[1], q2_chunks)
    return left @ right
150+
151+
152+
def map_blocks_multiple_outputs(
    func,
    *args,
    shapes,
    dtypes,
    chunkss,
    **kwargs,
):
    """Apply ``func`` block-wise to ``args``, producing several output arrays.

    Thin wrapper around ``general_blockwise``: each output block is computed
    from the blocks of the inputs that have the same block coordinates.

    Parameters
    ----------
    func
        Function applied to the input blocks.
    *args
        Input arrays.
    shapes, dtypes, chunkss
        One entry per output array, forwarded to ``general_blockwise``.
    **kwargs
        Extra keyword arguments forwarded to ``general_blockwise``.
    """

    def key_function(out_key):
        # keep the block coordinates, swap in each input array's name
        block_coords = out_key[1:]
        return tuple((arg.name,) + block_coords for arg in args)

    # one None target store per output
    stores = [None] * len(dtypes)

    return general_blockwise(
        func,
        key_function,
        *args,
        shapes=shapes,
        dtypes=dtypes,
        chunkss=chunkss,
        target_stores=stores,
        **kwargs,
    )

cubed/core/plan.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def _finalize(
268268
compile_function: Optional[Decorator] = None,
269269
array_names=None,
270270
) -> "FinalizedPlan":
271-
dag = self.optimize(optimize_function, array_names).dag if optimize_graph else self.dag
271+
dag = (
272+
self.optimize(optimize_function, array_names).dag
273+
if optimize_graph
274+
else self.dag
275+
)
272276
# create a copy since _create_lazy_zarr_arrays mutates the dag
273277
dag = dag.copy()
274278
if callable(compile_function):
@@ -501,6 +505,10 @@ def num_arrays(self) -> int:
501505
"""Return the number of arrays in this plan."""
502506
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))
503507

508+
def num_primitive_ops(self) -> int:
    """Return how many primitive operations this plan contains.

    This is the number of items yielded by ``visit_nodes`` for the plan's DAG.
    """
    return sum(1 for _ in visit_nodes(self.dag))
511+
504512
def num_tasks(self, resume=None):
505513
"""Return the number of tasks needed to execute this plan."""
506514
tasks = 0

cubed/tests/test_linalg.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import numpy as np
2+
import pytest
3+
from numpy.testing import assert_allclose
4+
5+
import cubed
6+
import cubed.array_api as xp
7+
from cubed.core.plan import arrays_to_plan
8+
9+
10+
def test_qr():
    """QR of a 16x2 array in four row chunks: check plan size and the factors."""
    expected = np.arange(32, dtype=np.float64).reshape(16, 2)
    Q, R = xp.linalg.qr(xp.asarray(expected, chunks=(4, 2)))

    plan = arrays_to_plan(Q, R)._finalize()
    assert plan.num_primitive_ops() == 4

    Q, R = cubed.compute(Q, R)

    identity = np.eye(2, 2)
    assert_allclose(Q @ R, expected, atol=1e-08)
    assert_allclose(Q.T @ Q, identity, atol=1e-08)  # Q must be orthonormal
    assert_allclose(R, np.triu(R), atol=1e-08)  # R must be upper triangular
22+
23+
24+
def test_qr_recursion():
    """Check QR on a memory setting small enough to trigger recursion.

    Increasing values of ``allowed_mem`` are tried until one is found for
    which the factorization can be built. Settings that raise ``ValueError``
    (not enough memory) are skipped.
    """
    A = np.reshape(np.arange(128, dtype=np.float64), (64, 2))

    # find a memory setting where recursion happens
    found = False
    for factor in range(4, 16):
        spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)

        try:
            Q, R = xp.linalg.qr(xp.asarray(A, chunks=(8, 2), spec=spec))

            plan_unopt = arrays_to_plan(Q, R)._finalize()
            assert plan_unopt.num_primitive_ops() > 4  # more than without recursion

            Q, R = cubed.compute(Q, R)

            assert_allclose(Q @ R, A, atol=1e-08)
            assert_allclose(Q.T @ Q, np.eye(2, 2), atol=1e-08)  # Q must be orthonormal
            assert_allclose(R, np.triu(R), atol=1e-08)  # R must be upper triangular
        except ValueError:
            continue  # not enough memory

        # only record success once every check above has actually passed, so
        # a swallowed ValueError can never leave the flag set
        found = True
        break

    assert found, "no memory setting allowed the recursive QR to be verified"
51+
52+
53+
def test_qr_chunking():
    """qr must reject an array with more than one chunk along the columns."""
    two_column_chunks = xp.ones((32, 4), chunks=(4, 2))
    expected_error = (
        r"qr only supports tall-and-skinny \(single column chunk\) arrays."
    )
    with pytest.raises(ValueError, match=expected_error):
        xp.linalg.qr(two_column_chunks)

cubed/tests/test_mem_utilization.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,19 @@ def test_sum_partial_reduce(tmp_path, spec, executor):
330330
run_operation(tmp_path, executor, "sum_partial_reduce", b)
331331

332332

333+
# Linear algebra extension
334+
335+
336+
@pytest.mark.slow
def test_qr(tmp_path, spec, executor):
    """Run QR on a 40000x1000 random array through ``run_operation``."""
    shape = (40000, 1000)
    chunks = (5000, 1000)  # 40MB chunks
    a = cubed.random.random(shape, chunks=chunks, spec=spec)
    q, r = xp.linalg.qr(a)
    # don't optimize graph so we use as much memory as possible (reading from Zarr)
    run_operation(tmp_path, executor, "qr", q, r, optimize_graph=False)
344+
345+
333346
# Multiple outputs
334347

335348

@@ -362,7 +375,7 @@ def run_operation(
362375
# )
363376
hist = HistoryCallback()
364377
mem_warn = MemoryWarningCallback()
365-
memray = MemrayCallback()
378+
memray = MemrayCallback(mem_threshold=30_000_000)
366379
# use None for each store to write to temporary zarr
367380
cubed.store(
368381
results,

0 commit comments

Comments
 (0)