11import logging
22import numpy as np
3+
4+ from pylops .utils .backend import get_array_module , to_cupy_conditional
35from pyproximal .ProxOperator import _check_tau
46from pyproximal import ProxOperator
57from pyproximal .projection import SimplexProj
68
79try :
810 from numba import jit
911 from ._Simplex_numba import bisect_jit , simplex_jit , fun_jit
12+ from ._Simplex_cuda import bisect_jit_cuda , simplex_jit_cuda , fun_jit_cuda
1013except ModuleNotFoundError :
1114 jit = None
1215 jit_message = 'Numba not available, reverting to numpy.'
2023class _Simplex (ProxOperator ):
2124 """Simplex operator (numpy version)
2225 """
23- def __init__ (self , n , radius , dims = None , axis = - 1 , maxiter = 100 , xtol = 1e-8 ,
24- call = True ):
26+ def __init__ (self , n , radius , dims = None , axis = - 1 ,
27+ maxiter = 100 , xtol = 1e-8 , call = True ):
2528 super ().__init__ (None , False )
2629 if dims is not None and len (dims ) != 2 :
2730 raise ValueError ('provide only 2 dimensions, or None' )
@@ -90,6 +93,7 @@ def __init__(self, n, radius, dims=None, axis=-1,
9093 self .xtol = xtol
9194 self .call = call
9295
96+ @_check_tau
9397 def prox (self , x , tau ):
9498 if self .dims is None :
9599 bisect_lower = - 1
@@ -113,6 +117,50 @@ def prox(self, x, tau):
113117 return y .ravel ()
114118
115119
120+ class _Simplex_cuda (_Simplex ):
121+ """Simplex operator (cuda version)
122+
123+ This implementation is adapted from https://github.com/DIG-Kaust/HPC_Hackathon_DIG.
124+
125+ """
126+ def __init__ (self , n , radius , dims = None , axis = - 1 ,
127+ maxiter = 100 , ftol = 1e-8 , xtol = 1e-8 , call = False ,
128+ num_threads_per_blocks = 32 ):
129+ super ().__init__ (None , False )
130+ if dims is not None and len (dims ) != 2 :
131+ raise ValueError ('provide only 2 dimensions, or None' )
132+ self .n = n
133+ # self.coeffs = cuda.to_device(np.ones(self.n if dims is None else dims[axis]))
134+ self .coeffs = np .ones (self .n if dims is None else dims [axis ])
135+ self .radius = radius
136+ self .dims = dims
137+ self .axis = axis
138+ self .otheraxis = 1 if axis == 0 else 0
139+ self .maxiter = maxiter
140+ self .ftol = ftol
141+ self .xtol = xtol
142+ self .call = call
143+ self .num_threads_per_blocks = num_threads_per_blocks
144+
145+ @_check_tau
146+ def prox (self , x , tau ):
147+ ncp = get_array_module (x )
148+ x = x .reshape (self .dims )
149+ if self .axis == 0 :
150+ x = x .T
151+ if type (self .coeffs ) != type (x ):
152+ self .coeffs = to_cupy_conditional (x , self .coeffs )
153+
154+ y = ncp .empty_like (x )
155+ num_blocks = (x .shape [0 ] + self .num_threads_per_blocks - 1 ) // self .num_threads_per_blocks
156+ simplex_jit_cuda [num_blocks , self .num_threads_per_blocks ](x , self .coeffs , self .radius ,
157+ 0 , 10000000000 , self .maxiter ,
158+ self .ftol , self .xtol , y )
159+ if self .axis == 0 :
160+ y = y .T
161+ return y .ravel ()
162+
163+
116164def Simplex (n , radius , dims = None , axis = - 1 , maxiter = 100 ,
117165 ftol = 1e-8 , xtol = 1e-8 , call = True , engine = 'numpy' ):
118166 r"""Simplex proximal operator.
@@ -137,18 +185,18 @@ def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
137185 maxiter : :obj:`int`, optional
138186 Maximum number of iterations used by bisection
139187 ftol : :obj:`float`, optional
140- Function tolerance in bisection (only with ``engine='numba'``)
188+ Function tolerance in bisection (only with ``engine='numba'`` or ``engine='cuda'`` )
141189 xtol : :obj:`float`, optional
142190 Solution absolute tolerance in bisection
143191 call : :obj:`bool`, optional
144192 Evalutate call method (``True``) or not (``False``)
145193 engine : :obj:`str`, optional
146- Engine used for simplex computation (``numpy`` or ``numba ``).
194+ Engine used for simplex computation (``numpy``, ``numba`` or ``cuda ``).
147195
148196 Raises
149197 ------
150198 KeyError
151- If ``engine`` is neither ``numpy`` nor ``numba``
199+ If ``engine`` is neither ``numpy`` nor ``numba`` nor ``cuda``
152200 ValueError
153201 If ``dims`` is provided as a list (or tuple) with more or less than
154202 2 elements
@@ -163,12 +211,15 @@ def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
163211 positive number can be provided.
164212
165213 """
166- if not engine in ['numpy' , 'numba' ]:
167- raise KeyError ('engine must be numpy or numba' )
214+ if not engine in ['numpy' , 'numba' , 'cuda' ]:
215+ raise KeyError ('engine must be numpy or numba or cuda ' )
168216
169217 if engine == 'numba' and jit is not None :
170218 s = _Simplex_numba (n , radius , dims = dims , axis = axis ,
171219 maxiter = maxiter , ftol = ftol , xtol = xtol , call = call )
220+ elif engine == 'cuda' and jit is not None :
221+ s = _Simplex_cuda (n , radius , dims = dims , axis = axis ,
222+ maxiter = maxiter , ftol = ftol , xtol = xtol , call = call )
172223 else :
173224 if engine == 'numba' and jit is None :
174225 logging .warning (jit_message )
0 commit comments