Skip to content

Commit 488dcbd

Browse files
authored
Merge pull request #136 from mrava87/feature-simplexcuda
Feature: added cuda version of Simplex proximal
2 parents 516afb8 + 32eae63 commit 488dcbd

File tree

2 files changed

+134
-7
lines changed

2 files changed

+134
-7
lines changed

pyproximal/proximal/Simplex.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import logging
22
import numpy as np
3+
4+
from pylops.utils.backend import get_array_module, to_cupy_conditional
35
from pyproximal.ProxOperator import _check_tau
46
from pyproximal import ProxOperator
57
from pyproximal.projection import SimplexProj
68

79
try:
810
from numba import jit
911
from ._Simplex_numba import bisect_jit, simplex_jit, fun_jit
12+
from ._Simplex_cuda import bisect_jit_cuda, simplex_jit_cuda, fun_jit_cuda
1013
except ModuleNotFoundError:
1114
jit = None
1215
jit_message = 'Numba not available, reverting to numpy.'
@@ -20,8 +23,8 @@
2023
class _Simplex(ProxOperator):
2124
"""Simplex operator (numpy version)
2225
"""
23-
def __init__(self, n, radius, dims=None, axis=-1, maxiter=100, xtol=1e-8,
24-
call=True):
26+
def __init__(self, n, radius, dims=None, axis=-1,
27+
maxiter=100, xtol=1e-8, call=True):
2528
super().__init__(None, False)
2629
if dims is not None and len(dims) != 2:
2730
raise ValueError('provide only 2 dimensions, or None')
@@ -90,6 +93,7 @@ def __init__(self, n, radius, dims=None, axis=-1,
9093
self.xtol = xtol
9194
self.call = call
9295

96+
@_check_tau
9397
def prox(self, x, tau):
9498
if self.dims is None:
9599
bisect_lower = -1
@@ -113,6 +117,50 @@ def prox(self, x, tau):
113117
return y.ravel()
114118

115119

120+
class _Simplex_cuda(_Simplex):
121+
"""Simplex operator (cuda version)
122+
123+
This implementation is adapted from https://github.com/DIG-Kaust/HPC_Hackathon_DIG.
124+
125+
"""
126+
def __init__(self, n, radius, dims=None, axis=-1,
127+
maxiter=100, ftol=1e-8, xtol=1e-8, call=False,
128+
num_threads_per_blocks=32):
129+
super().__init__(None, False)
130+
if dims is not None and len(dims) != 2:
131+
raise ValueError('provide only 2 dimensions, or None')
132+
self.n = n
133+
# self.coeffs = cuda.to_device(np.ones(self.n if dims is None else dims[axis]))
134+
self.coeffs = np.ones(self.n if dims is None else dims[axis])
135+
self.radius = radius
136+
self.dims = dims
137+
self.axis = axis
138+
self.otheraxis = 1 if axis == 0 else 0
139+
self.maxiter = maxiter
140+
self.ftol = ftol
141+
self.xtol = xtol
142+
self.call = call
143+
self.num_threads_per_blocks = num_threads_per_blocks
144+
145+
@_check_tau
146+
def prox(self, x, tau):
147+
ncp = get_array_module(x)
148+
x = x.reshape(self.dims)
149+
if self.axis == 0:
150+
x = x.T
151+
if type(self.coeffs) != type(x):
152+
self.coeffs = to_cupy_conditional(x, self.coeffs)
153+
154+
y = ncp.empty_like(x)
155+
num_blocks = (x.shape[0] + self.num_threads_per_blocks - 1) // self.num_threads_per_blocks
156+
simplex_jit_cuda[num_blocks, self.num_threads_per_blocks](x, self.coeffs, self.radius,
157+
0, 10000000000, self.maxiter,
158+
self.ftol, self.xtol, y)
159+
if self.axis == 0:
160+
y = y.T
161+
return y.ravel()
162+
163+
116164
def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
117165
ftol=1e-8, xtol=1e-8, call=True, engine='numpy'):
118166
r"""Simplex proximal operator.
@@ -137,18 +185,18 @@ def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
137185
maxiter : :obj:`int`, optional
138186
Maximum number of iterations used by bisection
139187
ftol : :obj:`float`, optional
140-
Function tolerance in bisection (only with ``engine='numba'``)
188+
Function tolerance in bisection (only with ``engine='numba'`` or ``engine='cuda'``)
141189
xtol : :obj:`float`, optional
142190
Solution absolute tolerance in bisection
143191
call : :obj:`bool`, optional
144192
Evalutate call method (``True``) or not (``False``)
145193
engine : :obj:`str`, optional
146-
Engine used for simplex computation (``numpy`` or ``numba``).
194+
Engine used for simplex computation (``numpy``, ``numba``or ``cuda``).
147195
148196
Raises
149197
------
150198
KeyError
151-
If ``engine`` is neither ``numpy`` nor ``numba``
199+
If ``engine`` is neither ``numpy`` nor ``numba`` nor ``cuda``
152200
ValueError
153201
If ``dims`` is provided as a list (or tuple) with more or less than
154202
2 elements
@@ -163,12 +211,15 @@ def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
163211
positive number can be provided.
164212
165213
"""
166-
if not engine in ['numpy', 'numba']:
167-
raise KeyError('engine must be numpy or numba')
214+
if not engine in ['numpy', 'numba', 'cuda']:
215+
raise KeyError('engine must be numpy or numba or cuda')
168216

169217
if engine == 'numba' and jit is not None:
170218
s = _Simplex_numba(n, radius, dims=dims, axis=axis,
171219
maxiter=maxiter, ftol=ftol, xtol=xtol, call=call)
220+
elif engine == 'cuda' and jit is not None:
221+
s = _Simplex_cuda(n, radius, dims=dims, axis=axis,
222+
maxiter=maxiter, ftol=ftol, xtol=xtol, call=call)
172223
else:
173224
if engine == 'numba' and jit is None:
174225
logging.warning(jit_message)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from numba import cuda
2+
3+
4+
@cuda.jit(device=True)
5+
def fun_jit_cuda(mu, x, coeffs, scalar, lower, upper):
6+
"""Bisection function"""
7+
p = 0
8+
for i in range(coeffs.shape[0]):
9+
p += coeffs[i] * min(max(x[i] - mu * coeffs[i], lower), upper)
10+
return p - scalar
11+
12+
13+
@cuda.jit(device=True)
14+
def bisect_jit_cuda(x, coeffs, scalar, lower, upper, bisect_lower, bisect_upper,
15+
maxiter, ftol, xtol):
16+
"""Bisection method (See _Simplex_numba for details).
17+
18+
"""
19+
a, b = bisect_lower, bisect_upper
20+
fa = fun_jit_cuda(a, x, coeffs, scalar, lower, upper)
21+
for iiter in range(maxiter):
22+
c = (a + b) / 2.
23+
if (b - a) / 2 < xtol:
24+
return c
25+
fc = fun_jit_cuda(c, x, coeffs, scalar, lower, upper)
26+
if abs(fc) < ftol:
27+
return c
28+
if fc / abs(fc) == fa / abs(fa):
29+
a = c
30+
fa = fc
31+
else:
32+
b = c
33+
return c
34+
35+
36+
@cuda.jit
37+
def simplex_jit_cuda(x, coeffs, scalar, lower, upper, maxiter, ftol, xtol, y):
38+
"""Simplex proximal
39+
40+
Parameters
41+
----------
42+
x : :obj:`np.ndarray`
43+
Input vector
44+
coeffs : :obj:`np.ndarray`
45+
Vector of coefficients used in the definition of the hyperplane
46+
scalar : :obj:`float`
47+
Scalar used in the definition of the hyperplane
48+
lower : :obj:`float` or :obj:`np.ndarray`, optional
49+
Lower bound of Box
50+
upper : :obj:`float` or :obj:`np.ndarray`, optional
51+
Upper bound of Box
52+
maxiter : :obj:`int`, optional
53+
Maximum number of iterations
54+
ftol : :obj:`float`, optional
55+
Function tolerance
56+
xtol : :obj:`float`, optional
57+
Solution absolute tolerance
58+
y : :obj:`np.ndarray`
59+
Output vector
60+
61+
"""
62+
i = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
63+
64+
if i < x.shape[0]:
65+
bisect_lower = -1
66+
while fun_jit_cuda(bisect_lower, x[i], coeffs, scalar, lower, upper) < 0:
67+
bisect_lower *= 2
68+
bisect_upper = 1
69+
while fun_jit_cuda(bisect_upper, x[i], coeffs, scalar, lower, upper) > 0:
70+
bisect_upper *= 2
71+
72+
c = bisect_jit_cuda(x[i], coeffs, scalar, lower, upper,
73+
bisect_lower, bisect_upper, maxiter, ftol, xtol)
74+
75+
for j in range(coeffs.shape[0]):
76+
y[i][j] = min(max(x[i][j] - c * coeffs[j], lower), upper)

0 commit comments

Comments
 (0)