Skip to content

Commit a00d5a0

Browse files
committed
implement hierarchical matrix solver
1 parent 0176d32 commit a00d5a0

File tree

5 files changed

+561
-8
lines changed

5 files changed

+561
-8
lines changed

pytential/linalg/cluster.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,23 @@ def make_block(i: int, j: int):
220220
else:
221221
raise ValueError(f"unsupported ndarray dimension: '{ndim}'")
222222

223+
224+
def uncluster(
        ary: np.ndarray, index: IndexList, clevel: ClusterLevel,
        ) -> np.ndarray:
    """Split the per-parent entries of *ary* back into per-cluster pieces.

    :arg ary: an array with one entry for each entry in
        ``clevel.parent_map``; each entry is sliced along its first axis.
    :arg index: provides the number of clusters (``nclusters``) and the size
        of each cluster (``cluster_size``) used to cut the slices.
    :arg clevel: provides ``parent_map``, which lists, for every entry of
        *ary*, the clusters it is split into.
    :returns: an object :class:`~numpy.ndarray` of size ``index.nclusters``
        whose entries are consecutive slices of the entries of *ary*. If
        *index* only has a single cluster, *ary* itself is returned.
    """
    assert ary.shape == (clevel.parent_map.size,)

    # NOTE: with a single cluster the input is handed back unchanged
    if index.nclusters == 1:
        return ary

    pieces = np.empty(index.nclusters, dtype=object)
    for iparent, children in enumerate(clevel.parent_map):
        # the children are stored back-to-back in the parent entry, in the
        # order in which they appear in the parent map
        start = 0
        for ichild in children:
            stop = start + index.cluster_size(ichild)
            pieces[ichild] = ary[iparent][start:stop]
            start = stop

    return pieces
239+
223240
# }}}
224241

225242

pytential/linalg/hmatrix.py

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
__copyright__ = "Copyright (C) 2022 Alexandru Fikl"
2+
3+
__license__ = """
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in
12+
all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
THE SOFTWARE.
21+
"""
22+
23+
from dataclasses import dataclass
24+
from typing import Any, Dict, Iterable, Optional, Union
25+
26+
import numpy as np
27+
import numpy.linalg as la
28+
29+
from arraycontext import PyOpenCLArrayContext, ArrayOrContainerT, flatten, unflatten
30+
from meshmode.dof_array import DOFArray
31+
32+
from pytential import GeometryCollection, sym
33+
from pytential.linalg.cluster import ClusterTree
34+
from pytential.linalg.skeletonization import SkeletonizationWrangler
35+
36+
__doc__ = """
37+
Hierarchical Matrix Construction
38+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39+
"""
40+
41+
42+
# {{{ ProxyHierarchicalMatrix
43+
44+
@dataclass(frozen=True)
class ProxyHierarchicalMatrix:
    """A hierarchical matrix described by per-level skeletonization data.

    .. attribute:: skeletons

        An :class:`~numpy.ndarray` containing skeletonization information
        for each level of the hierarchy. For additional details, see
        :class:`~pytential.linalg.skeletonization.SkeletonizationResult`.

    This class implements the :class:`scipy.sparse.linalg.LinearOperator`
    interface. In particular, the following attributes and methods:

    .. attribute:: shape

        A :class:`tuple` that gives the matrix size ``(m, n)``.

    .. attribute:: dtype

        The data type of the matrix entries.

    .. automethod:: matvec
    .. automethod:: __matmul__
    """

    wrangler: SkeletonizationWrangler
    ctree: ClusterTree
    # selects what :meth:`matvec` applies: "forward" or "backward"
    method: str

    skeletons: np.ndarray

    @property
    def shape(self):
        # the first entry of the skeleton array carries the full index sets
        first_level = self.skeletons[0]
        return first_level.tgt_src_index.shape

    @property
    def dtype(self):
        # FIXME: assert that everyone has this dtype?
        first_level = self.skeletons[0]
        return first_level.R[0].dtype

    @property
    def nlevels(self):
        return self.skeletons.size

    @property
    def nclusters(self):
        first_level = self.skeletons[0]
        return first_level.nclusters

    def matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
        """Implements a matrix-vector multiplication :math:`H x`.

        Depending on :attr:`method`, this dispatches to
        :func:`apply_skeleton_matvec` (``"forward"``) or to
        :func:`apply_skeleton_inverse_matvec` (``"backward"``).
        """
        from arraycontext import get_container_context_recursively_opt
        actx = get_container_context_recursively_opt(x)
        if actx is None:
            raise ValueError("input array is frozen")

        if self.method not in ("forward", "backward"):
            raise ValueError(f"unknown matvec method: '{self.method}'")

        apply = (
            apply_skeleton_matvec if self.method == "forward"
            else apply_skeleton_inverse_matvec)
        return apply(actx, self, x)

    def __matmul__(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
        """Same as :meth:`matvec`."""
        return self.matvec(x)

    def rmatvec(self, x):
        raise NotImplementedError

    def matmat(self, mat):
        raise NotImplementedError

    def rmatmat(self, mat):
        raise NotImplementedError
117+
118+
119+
def apply_skeleton_matvec(
        actx: PyOpenCLArrayContext,
        hmat: ProxyHierarchicalMatrix,
        ary: ArrayOrContainerT,
        ) -> ArrayOrContainerT:
    """Apply the skeletonized operator stored in *hmat* to *ary*.

    :arg actx: array context used to move *ary* to and from :mod:`numpy`.
    :arg hmat: provides the per-level skeletons and the cluster tree.
    :arg ary: a :class:`~meshmode.dof_array.DOFArray`.
    :returns: an array container with the same structure as *ary*.
    :raises TypeError: if *ary* is not a
        :class:`~meshmode.dof_array.DOFArray`.
    """
    if not isinstance(ary, DOFArray):
        raise TypeError(f"unsupported array container: '{type(ary).__name__}'")

    from pytential.linalg.cluster import cluster, uncluster
    from pytential.linalg.utils import split_array

    leaf_targets, leaf_sources = hmat.skeletons[0].tgt_src_index
    flat_ary = actx.to_numpy(flatten(ary, actx))
    x = split_array(flat_ary, leaf_sources)

    # NOTE: the product is evaluated as a telescoping expression
    #
    #   A x_0 = (D0 + L0 (D1 + L1 (...) R1) R0) x_0
    #
    # for any number of levels. On the way down, each level k stores
    #
    #   z_{k + 1} = D_k x_k
    #
    # and passes on
    #
    #   x_{k + 1} = R_k x_k
    #
    # At the last (root) level, the result is just the diagonal product
    # `D_N x_N`. On the way back up, starting from that root product, the
    # stored diagonal terms are combined as
    #
    #   b_{k - 1} = z_k + L_k b_k
    #
    # until the first level is reached again.

    levels = list(hmat.ctree.levels(root=True))
    diag_products = np.empty(hmat.nlevels, dtype=object)

    # {{{ recurse down

    for ilevel, level in enumerate(levels):
        skel = hmat.skeletons[ilevel]
        assert x.shape == (skel.nclusters,)
        assert skel.tgt_src_index.shape[1] == sum(xi.size for xi in x)

        dx = np.empty(skel.nclusters, dtype=object)
        rx = np.empty(skel.nclusters, dtype=object)

        for icluster in range(skel.nclusters):
            rx[icluster] = skel.R[icluster] @ x[icluster]
            dx[icluster] = skel.D[icluster] @ x[icluster]

        diag_products[ilevel] = dx
        x = cluster(rx, level)

    # }}}

    # {{{ root

    # NOTE: at root level, we just multiply with the full diagonal
    b = diag_products[hmat.nlevels - 1]
    assert b.shape == (1,)

    # }}}

    # {{{ recurse up

    # walk over all but the last level, from coarsest to finest
    for ilevel in range(len(levels) - 2, -1, -1):
        level = levels[ilevel]
        skel = hmat.skeletons[ilevel]
        dx = diag_products[ilevel]
        assert dx.shape == (skel.nclusters,)

        b = uncluster(b, skel.skel_tgt_src_index.targets, level)
        for icluster in range(skel.nclusters):
            b[icluster] = dx[icluster] + skel.L[icluster] @ b[icluster]

    assert b.shape == (hmat.nclusters,)

    # }}}

    # the cluster-ordered result is permuted by the argsort of the
    # target indices before being wrapped in a container like *ary*
    permutation = np.argsort(leaf_targets.indices)
    result = np.concatenate(b)[permutation]
    return unflatten(ary, actx.from_numpy(result), actx)
201+
202+
203+
def apply_skeleton_inverse_matvec(
        actx: PyOpenCLArrayContext,
        hmat: ProxyHierarchicalMatrix,
        ary: ArrayOrContainerT,
        ) -> ArrayOrContainerT:
    """Apply the skeleton-based (approximate) inverse stored in *hmat* to
    *ary*.

    The right-hand side is compressed level by level on the way down, a
    dense solve is performed with the diagonal block of the last level and
    the solution is expanded level by level on the way back up.

    :arg actx: array context used to move *ary* to and from :mod:`numpy`.
    :arg hmat: provides the per-level skeletons and the cluster tree.
    :arg ary: a :class:`~meshmode.dof_array.DOFArray`.
    :returns: an array container with the same structure as *ary*.
    :raises TypeError: if *ary* is not a
        :class:`~meshmode.dof_array.DOFArray`.
    """
    if not isinstance(ary, DOFArray):
        raise TypeError(f"unsupported array container: '{type(ary).__name__}'")

    from pytential.linalg.utils import split_array
    targets, sources = hmat.skeletons[0].tgt_src_index

    b = split_array(actx.to_numpy(flatten(ary, actx)), targets)
    inv_dhat_dot_b = np.empty(hmat.nlevels, dtype=object)

    # {{{ recurse down

    from pytential.linalg.cluster import cluster
    clevels = list(hmat.ctree.levels(root=True))

    for k, clevel in enumerate(clevels):
        skeleton = hmat.skeletons[k]
        assert b.shape == (skeleton.nclusters,)
        assert skeleton.tgt_src_index.shape[0] == sum(bi.size for bi in b)

        # compressed right-hand side for this level: Dhat R D^{-1} b
        inv_dhat_dot_b_k = np.empty(skeleton.nclusters, dtype=object)
        for i in range(skeleton.nclusters):
            inv_dhat_dot_b_k[i] = (
                skeleton.Dhat[i] @ (skeleton.R[i] @ (skeleton.invD[i] @ b[i]))
                )

        inv_dhat_dot_b[k] = inv_dhat_dot_b_k
        b = cluster(inv_dhat_dot_b_k, clevel)

    # }}}

    # {{{ root

    from pytools.obj_array import make_obj_array
    assert b.shape == (1,)
    x = make_obj_array([
        la.solve(D, bi) for D, bi in zip(hmat.skeletons[-1].D, b)
        ])

    # }}}

    # {{{ recurse up

    from pytential.linalg.cluster import uncluster

    for k, clevel in reversed(list(enumerate(clevels[:-1]))):
        skeleton = hmat.skeletons[k]
        inv_dhat_dot_b_k0 = inv_dhat_dot_b[k]
        inv_dhat_dot_b_k1 = inv_dhat_dot_b[k + 1]
        # NOTE: check the data stored for *this* level; a leftover from the
        # last iteration of the downward pass has the root-level shape
        assert inv_dhat_dot_b_k0.shape == (skeleton.nclusters,)

        x = uncluster(x, skeleton.skel_tgt_src_index.sources, clevel)
        inv_dhat_dot_b_k1 = uncluster(
            inv_dhat_dot_b_k1, skeleton.skel_tgt_src_index.sources, clevel)

        # NOTE(review): the update below is reproduced from the original
        # implementation; only the per-cluster indexing of `Dhat` was made
        # consistent with the downward pass -- the algebra itself should be
        # confirmed against the skeletonization reference
        for i in range(skeleton.nclusters):
            x[i] = skeleton.invD[i] @ (
                inv_dhat_dot_b_k0[i]
                - skeleton.L[i] @ inv_dhat_dot_b_k1[i]
                + skeleton.L[i] @ (skeleton.Dhat[i] @ x[i])
                )

    assert x.shape == (hmat.nclusters,)

    # }}}

    # the cluster-ordered solution is permuted by the argsort of the
    # source indices before being wrapped in a container like *ary*
    x = np.concatenate(x)[np.argsort(sources.indices)]
    return unflatten(ary, actx.from_numpy(x), actx)
278+
279+
# }}}
280+
281+
282+
# {{{ build_hmatrix_matvec_by_proxy
283+
284+
def build_hmatrix_matvec_by_proxy(
        actx: PyOpenCLArrayContext,
        places: GeometryCollection,
        exprs: Union[sym.Expression, Iterable[sym.Expression]],
        input_exprs: Union[sym.Expression, Iterable[sym.Expression]], *,
        method: str = "forward",
        domains: Optional[Iterable[sym.DOFDescriptorLike]] = None,
        context: Optional[Dict[str, Any]] = None,
        id_eps: float = 1.0e-8,

        # NOTE: these are dev variables and can disappear at any time!
        # TODO: plug in an error model to get an estimate for:
        #   * how many points we want per cluster?
        #   * how many proxy points we want?
        #   * how far away should the proxy points be?
        # based on id_eps. How many of these should be user tunable?
        _tree_kind: Optional[str] = "adaptive-level-restricted",
        _max_particles_in_box: Optional[int] = None,

        _approx_nproxy: Optional[int] = None,
        _proxy_radius_factor: Optional[float] = None,
        ) -> ProxyHierarchicalMatrix:
    """Construct a :class:`ProxyHierarchicalMatrix` using proxy-based
    recursive skeletonization.

    The nodes are partitioned into clusters, the same cluster index is used
    for both targets and sources, and the resulting blocks are recursively
    skeletonized.

    :arg exprs: expressions forwarded to the skeletonization wrangler and
        to the recursive skeletonization.
    :arg input_exprs: input expressions forwarded together with *exprs*.
    :arg method: one of ``"forward"`` or ``"backward"``; selects what
        :meth:`ProxyHierarchicalMatrix.matvec` applies.
    :arg domains: forwarded to the skeletonization wrangler.
    :arg context: forwarded to the skeletonization wrangler.
    :arg id_eps: tolerance forwarded to the recursive skeletonization.
    :returns: a :class:`ProxyHierarchicalMatrix`.
    :raises ValueError: if *method* is not a known matvec method.
    """
    if method not in ("forward", "backward"):
        raise ValueError(f"unknown matvec method: '{method}'")

    from pytential.linalg.cluster import partition_by_nodes
    cluster_index, ctree = partition_by_nodes(
        actx, places,
        tree_kind=_tree_kind,
        max_particles_in_box=_max_particles_in_box)

    # targets and sources share the same clustering
    from pytential.linalg.utils import TargetAndSourceClusterList
    tgt_src_index = TargetAndSourceClusterList(
        targets=cluster_index, sources=cluster_index)

    from pytential.linalg.proxy import QBXProxyGenerator
    proxy = QBXProxyGenerator(places,
        approx_nproxy=_approx_nproxy,
        radius_factor=_proxy_radius_factor)

    from pytential.linalg.skeletonization import make_skeletonization_wrangler
    wrangler = make_skeletonization_wrangler(
        places, exprs, input_exprs,
        domains=domains, context=context)

    from pytential.linalg.skeletonization import rec_skeletonize_by_proxy
    skeletons = rec_skeletonize_by_proxy(
        actx, places, ctree, tgt_src_index, exprs, input_exprs,
        id_eps=id_eps,
        max_particles_in_box=_max_particles_in_box,
        _proxy=proxy,
        _wrangler=wrangler,
        )

    # NOTE: no additional setup is currently performed for the "backward"
    # method; the skeletons are used as returned above

    return ProxyHierarchicalMatrix(
        wrangler=wrangler, ctree=ctree, method=method, skeletons=skeletons)
343+
344+
# }}}

0 commit comments

Comments
 (0)