1313# At this point there are no checks of input arguments whatsoever, arguments
1414# are simply forwarded as-is.
1515
16+ import os
17+ import re
1618from importlib import import_module
1719from os import getenv
1820from typing import Any
4244
4345
4446def init (cw = None ):
47+ libidtr = os .path .join (os .path .dirname (__file__ ), "libidtr.so" )
48+ assert os .path .isfile (libidtr ), "libidtr.so not found"
49+
4550 cw = _sharpy_cw if cw is None else cw
46- _init (cw )
51+ _init (cw , libidtr )
4752
4853
4954def to_numpy (a ):
@@ -64,31 +69,42 @@ def to_numpy(a):
6469 f"{ op } = lambda this: ndarray(_csp.EWUnyOp.op(_csp.{ OP } , this._t))"
6570 )
6671
72+
73+ def _validate_device (device ):
74+ if len (device ) == 0 or re .search (
75+ r"^((opencl|level-zero|cuda):)?(host|gpu|cpu|accelerator)(:\d+)?$" ,
76+ device ,
77+ ):
78+ return device
79+ else :
80+ raise ValueError (f"Invalid device string: { device } " )
81+
82+
6783for func in api .api_categories ["Creator" ]:
6884 FUNC = func .upper ()
6985 if func == "full" :
7086 exec (
71- f"{ func } = lambda shape, val, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, val, dtype, device, team))"
87+ f"{ func } = lambda shape, val, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, val, dtype, _validate_device( device) , team))"
7288 )
7389 elif func == "empty" :
7490 exec (
75- f"{ func } = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, None, dtype, device, team))"
91+ f"{ func } = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, None, dtype, _validate_device( device) , team))"
7692 )
7793 elif func == "ones" :
7894 exec (
79- f"{ func } = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 1, dtype, device, team))"
95+ f"{ func } = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 1, dtype, _validate_device( device) , team))"
8096 )
8197 elif func == "zeros" :
8298 exec (
83- f"{ func } = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, device, team))"
99+ f"{ func } = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device( device) , team))"
84100 )
85101 elif func == "arange" :
86102 exec (
87- f"{ func } = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, device, team))"
103+ f"{ func } = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device( device) , team))"
88104 )
89105 elif func == "linspace" :
90106 exec (
91- f"{ func } = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, device, team))"
107+ f"{ func } = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device( device) , team))"
92108 )
93109
94110for func in api .api_categories ["ReduceOp" ]:
@@ -116,10 +132,17 @@ def to_numpy(a):
116132
117133_fb_env = getenv ("SHARPY_FALLBACK" )
118134if _fb_env is not None :
135+ if not _fb_env .isalnum ():
136+ raise ValueError (f"Invalid SHARPY_FALLBACK value '{ _fb_env } '" )
119137
120138 class _fallback :
121139 "Fallback to whatever is provided in SHARPY_FALLBACK"
122- _fb_lib = import_module (_fb_env )
140+ try :
141+ _fb_lib = import_module (_fb_env )
142+ except ModuleNotFoundError :
143+ raise ValueError (
144+ f"Invalid SHARPY_FALLBACK value '{ _fb_env } ': module not found"
145+ )
123146
124147 def __init__ (self , fname : str , mod = None ) -> None :
125148 """get callable with name 'fname' from fallback-lib
0 commit comments