diff --git a/pyccel/ast/cudaext.py b/pyccel/ast/cudaext.py index 3f04c272ee..d4efab3716 100644 --- a/pyccel/ast/cudaext.py +++ b/pyccel/ast/cudaext.py @@ -1,9 +1,9 @@ from .basic import PyccelAstNode from .builtins import (PythonTuple,PythonList) -from .core import Module, PyccelFunctionDef +from .core import Module, PyccelFunctionDef, Import -from .datatypes import NativeInteger +from .datatypes import NativeInteger, NativeVoid, NativeFloat, TimeVal from .internals import PyccelInternalFunction, get_final_precision @@ -26,7 +26,8 @@ 'CudaMemCopy', 'CudaNewArray', 'CudaSynchronize', - 'CudaThreadIdx' + 'CudaThreadIdx', + 'CudaTime' ) #============================================================================== @@ -109,6 +110,34 @@ def arg(self): def memory_location(self): return self._memory_location +class CudaSharedArray(CudaNewArray): + """ + Represents a call to cuda.shared.array for code generation. + + arg : list, tuple, PythonList + + """ + + __slots__ = ('_dtype','_precision','_shape','_rank','_order', '_memory_location') + name = 'array' + + def __init__(self, shape, dtype, order='C'): + + # Convert shape to PythonTuple + self._shape = process_shape(False, shape) + + # Verify dtype and get precision + self._dtype, self._precision = process_dtype(dtype) + + self._rank = len(self._shape) + self._order = self._order = NumpyNewArray._process_order(self._rank, order) + self._memory_location = 'shared' + super().__init__() + + @property + def memory_location(self): + return self._memory_location + class CudaSynchronize(PyccelInternalFunction): "Represents a call to Cuda.deviceSynchronize for code generation." @@ -116,7 +145,20 @@ class CudaSynchronize(PyccelInternalFunction): _attribute_nodes = () _shape = None _rank = 0 - _dtype = NativeInteger() + _dtype = NativeVoid() + _precision = None + _order = None + def __init__(self): + super().__init__() + +class CudaSyncthreads(PyccelInternalFunction): + "Represents a call to __syncthreads for code generation." + + __slots__ = () + _attribute_nodes = () + _shape = None + _rank = 0 + _dtype = NativeVoid() _precision = None _order = None def __init__(self): @@ -156,6 +198,47 @@ def __init__(self, dim=None): def dim(self): return self._dim +class CudaTime(PyccelInternalFunction): + __slots__ = () + _attribute_nodes = () + _shape = None + _rank = 0 + _dtype = TimeVal() + _precision = 0 + _order = None + def __init__(self): + super().__init__() + +class CudaTimeDiff(PyccelAstNode): + """ + Represents a General Class For Cuda internal Variables Used To locate Thread In the GPU architecture" + + Parameters + ---------- + dim : NativeInteger + Represent the dimension where we want to locate our thread. + + """ + __slots__ = ('_start','_end', '_dtype', '_precision') + _attribute_nodes = ('_start','_end',) + _shape = None + _rank = 0 + _order = None + + def __init__(self, start=None, end=None): + #... + self._start = start + self._end = end + self._dtype = NativeFloat() + self._precision = 8 + super().__init__() + + @property + def start(self): + return self._start + @property + def end(self): + return self._end class CudaCopy(CudaNewArray): """ @@ -251,17 +334,19 @@ def __new__(cls, dim=0): return expr[0] return PythonTuple(*expr) - - cuda_funcs = { 'array' : PyccelFunctionDef('array' , CudaArray), 'copy' : PyccelFunctionDef('copy' , CudaCopy), 'synchronize' : PyccelFunctionDef('synchronize' , CudaSynchronize), + 'syncthreads' : PyccelFunctionDef('syncthreads' , CudaSyncthreads), 'threadIdx' : PyccelFunctionDef('threadIdx' , CudaThreadIdx), 'blockDim' : PyccelFunctionDef('blockDim' , CudaBlockDim), 'blockIdx' : PyccelFunctionDef('blockIdx' , CudaBlockIdx), 'gridDim' : PyccelFunctionDef('gridDim' , CudaGridDim), - 'grid' : PyccelFunctionDef('grid' , CudaGrid) + 'grid' : PyccelFunctionDef('grid' , CudaGrid), + 'time' : PyccelFunctionDef('time' , CudaTime), + 'timediff' : PyccelFunctionDef('timediff' , CudaTimeDiff), + } cuda_Internal_Var = { @@ -271,6 +356,16 @@ def __new__(cls, dim=0): 'CudaGridDim' : 'gridDim' } +# cuda_sharedmemory = { +# 'array' : PyccelFunctionDef('array' , CudaSharedArray), +# } + +cuda_sharedmemory = Module('shared', (), + [ PyccelFunctionDef('array' , CudaSharedArray)]) + cuda_mod = Module('cuda', variables = [], - funcs = cuda_funcs.values()) \ No newline at end of file + funcs = cuda_funcs.values(), + imports = [ + Import('shared', cuda_sharedmemory), + ]) \ No newline at end of file diff --git a/pyccel/ast/datatypes.py b/pyccel/ast/datatypes.py index 6d17610a8a..fa6cc05f3c 100644 --- a/pyccel/ast/datatypes.py +++ b/pyccel/ast/datatypes.py @@ -377,3 +377,9 @@ def str_dtype(dtype): return 'bool' else: raise TypeError('Unknown datatype {0}'.format(str(dtype))) + + +class TimeVal(DataType): + """Class representing timeval datatype""" + __slots__ = () + _name = 'timeval' \ No newline at end of file diff --git a/pyccel/ast/operators.py b/pyccel/ast/operators.py index d863a85d00..e07ae77c2d 100644 --- a/pyccel/ast/operators.py +++ b/pyccel/ast/operators.py @@ -18,6 +18,7 @@ from .datatypes import (NativeBool, NativeInteger, NativeFloat, NativeComplex, NativeString, NativeNumeric) +from .datatypes import TimeVal from .internals import max_precision @@ -393,6 +394,7 @@ def _calculate_dtype(cls, *args): floats = [a for a in args if a.dtype is NativeFloat()] complexes = [a for a in args if a.dtype is NativeComplex()] strs = [a for a in args if a.dtype is NativeString()] + time = [a for a in args if a.dtype is TimeVal()] if strs: assert len(integers + floats + complexes) == 0 @@ -403,6 +405,8 @@ def _calculate_dtype(cls, *args): return cls._handle_float_type(args) elif integers: return cls._handle_integer_type(args) + elif time: + return time else: raise TypeError('cannot determine the type of {}'.format(args)) diff --git a/pyccel/ast/variable.py b/pyccel/ast/variable.py index 550a51e3c1..3b161e5c60 100644 --- a/pyccel/ast/variable.py +++ b/pyccel/ast/variable.py @@ -166,8 +166,8 @@ def __init__( raise ValueError("memory_handling must be 'heap', 'stack' or 'alias'") self._memory_handling = memory_handling - if memory_location not in ('host', 'device', 'managed'): - raise ValueError("memory_location must be 'host', 'device' or 'managed'") + if memory_location not in ('host', 'device', 'managed', 'shared'): + raise ValueError("memory_location must be 'host', 'device' , 'shared' or 'managed'") self._memory_location = memory_location if not isinstance(is_const, bool): diff --git a/pyccel/codegen/printing/ccudacode.py b/pyccel/codegen/printing/ccudacode.py index bfefbdc3dd..48b46edc49 100644 --- a/pyccel/codegen/printing/ccudacode.py +++ b/pyccel/codegen/printing/ccudacode.py @@ -6,23 +6,30 @@ # pylint: disable=missing-function-docstring +from functools import reduce from pyccel.ast.builtins import PythonTuple from pyccel.ast.core import (FunctionCall, Deallocate, FunctionAddress, FunctionDefArgument, Assign, Import, - AliasAssign, Module) + AliasAssign, Module, Declare, AsName) +from pyccel.ast.datatypes import NativeInteger, NativeComplex, NativeBool, TimeVal, default_precision from pyccel.ast.datatypes import NativeTuple, datatype -from pyccel.ast.literals import LiteralTrue, Literal, Nil +from pyccel.ast.literals import LiteralTrue, LiteralString, Literal, Nil +from pyccel.ast.c_concepts import ObjectAddress, CMacro, CStringExpression from pyccel.ast.numpyext import NumpyFull, NumpyArray, NumpyArange from pyccel.ast.cupyext import CupyFull, CupyArray, CupyArange -from pyccel.ast.cudaext import CudaCopy, cuda_Internal_Var, CudaArray +from pyccel.ast.cudaext import CudaCopy, cuda_Internal_Var, CudaArray, CudaSharedArray, CudaTime -from pyccel.ast.variable import Variable +from pyccel.ast.operators import PyccelMul, PyccelUnarySub +from pyccel.ast.variable import Variable, PyccelArraySize +from pyccel.ast.variable import InhomogeneousTupleVariable, DottedName + +from pyccel.ast.internals import Slice, get_final_precision from pyccel.ast.c_concepts import ObjectAddress from pyccel.codegen.printing.ccode import CCodePrinter @@ -183,6 +190,7 @@ ('int',8) : 'int64_t', ('int',2) : 'int16_t', ('int',1) : 'int8_t', + ('timeval', 0): 'struct timeval', ('bool',4) : 'bool'} ndarray_type_registry = { @@ -204,15 +212,18 @@ 'string', 'ndarrays', 'cuda_ndarrays', + 'ho_cuda_ndarrays', 'math', 'complex', 'stdint', 'pyc_math_c', 'stdio', 'stdbool', + 'sys/time', 'assert']} class CCudaCodePrinter(CCodePrinter): + print("---------") """A printer to convert python expressions to strings of ccuda code""" printmethod = "_ccudacode" language = "ccuda" @@ -287,14 +298,149 @@ def get_var_arg(arg, var): if self._additional_args : self._additional_args.pop() + #TODO: need to check if "extern C" is necessary. extern_word = 'extern "C"' - cuda_deco = "__global__" if 'kernel' in expr.decorators else '' + # extern_word = '' + + cuda_deco = '' + if 'kernel' in expr.decorators: + cuda_deco = "__global__" + elif 'device' in expr.decorators: + cuda_deco = "__device__" if isinstance(expr, FunctionAddress): return f'{extern_word} {ret_type} (*{name})({arg_code})' else: return f'{extern_word} {cuda_deco} {ret_type} {name}({arg_code})' + def _print_Import(self, expr): + if expr.ignore: + return '' + if isinstance(expr.source, AsName): + source = expr.source.name + else: + source = expr.source + if isinstance(source, DottedName): + source = source.name[-1] + else: + source = self._print(source) + + # Get with a default value is not used here as it is + # slower and on most occasions the import will not be in the + # dictionary + if source in import_dict: # pylint: disable=consider-using-get + source = import_dict[source] + + if source is None: + return '' + if expr.source in c_library_headers: + return '#include <{0}.h>\n'.format(source) + else: + if len(source) > 3 and source[:2] == 'ho': + # self._additional_imports.pop(source) + return f'#define HO_CUDA_PYCCEL\n#include "{source}.h"\n' + + return f'#include "{source}.h"\n' + + def _print_Declare(self, expr): + if isinstance(expr.variable, InhomogeneousTupleVariable): + return ''.join(self._print_Declare(Declare(v.dtype,v,intent=expr.intent, static=expr.static)) for v in expr.variable) + + declaration_type = self.get_declare_type(expr.variable) + variable = self._print(expr.variable.name) + + if expr.variable.memory_location == 'shared': + preface, init = self._init_shared_array(expr.variable) + elif expr.variable.is_stack_array: + preface, init = self._init_stack_array(expr.variable,) + elif declaration_type == 't_ndarray' and not self._in_header: + preface = '' + init = ' = {.shape = NULL}' + else: + preface = '' + init = '' + + declaration = f'{declaration_type} {variable}{init};\n' + + return preface + declaration + + def _init_shared_array(self, expr): + """ return a string which handles the assignment of a shared ndarray + + Parameters + ---------- + expr : PyccelAstNode + The Assign Node used to get the lhs and rhs + Returns + ------- + buffer_array : str + String initialising the shared (C) array which stores the data + array_init : str + String containing the rhs of the initialization of a stack array + """ + var = expr + dtype_str = self._print(var.dtype) + dtype = self.find_in_dtype_registry(dtype_str, var.precision) + np_dtype = self.find_in_ndarray_type_registry(dtype_str, var.precision) + shape = ", ".join(self._print(i) for i in var.alloc_shape) + tot_shape = self._print(reduce( + lambda x,y: PyccelMul(x,y,simplify=True), var.alloc_shape)) + declare_dtype = self.find_in_dtype_registry('int', 8) + + dummy_array_name = self.scope.get_new_name('array_dummy') + is_shared = '__shared__' if expr.memory_location == 'shared' else '' + buffer_array = f'{is_shared} {dtype} {dummy_array_name}[{tot_shape}];\n' + shape_init = "({declare_dtype}[]){{{shape}}}".format(declare_dtype=declare_dtype, shape=shape) + strides_init = "({declare_dtype}[{length}]){{0}}".format(declare_dtype=declare_dtype, length=len(var.shape)) + array_init = ' = (t_ndarray){{\n.{0}={1},\n .nd={4},\n .shape={2},\n' + array_init += '.strides={3},\n .type={0},\n .is_view={5}\n}};\n' + array_init = array_init.format(np_dtype, dummy_array_name, + shape_init, strides_init, len(var.shape), 'false') + # TODO: call this only one time per block (need to check threadIdx.xyz == 0 then and only then run the code) + array_init += 'shared_array_init(&{})'.format(self._print(var)) + self.add_import(c_imports['ho_cuda_ndarrays']) + return buffer_array, array_init + + + def _init_stack_array(self, expr): + """ return a string which handles the assignment of a stack ndarray + + Parameters + ---------- + expr : PyccelAstNode + The Assign Node used to get the lhs and rhs + Return + ------- + buffer_array : str + String initialising the stack (C) array which stores the data + array_init : str + String containing the rhs of the initialization of a stack array + """ + var = expr + dtype_str = self._print(var.dtype) + dtype = self.find_in_dtype_registry(dtype_str, var.precision) + np_dtype = self.find_in_ndarray_type_registry(dtype_str, var.precision) + shape = ", ".join(self._print(i) for i in var.alloc_shape) + tot_shape = self._print(reduce( + lambda x,y: PyccelMul(x,y,simplify=True), var.alloc_shape)) + declare_dtype = self.find_in_dtype_registry('int', 8) + + dummy_array_name = self.scope.get_new_name('array_dummy') + buffer_array = "{dtype} {name}[{size}];\n".format( + dtype = dtype, + name = dummy_array_name, + size = tot_shape) + shape_init = "({declare_dtype}[]){{{shape}}}".format(declare_dtype=declare_dtype, shape=shape) + strides_init = "({declare_dtype}[{length}]){{0}}".format(declare_dtype=declare_dtype, length=len(var.shape)) + array_init = ' = (t_ndarray){{\n.{0}={1},\n .shape={2},\n .strides={3},\n ' + array_init += '.nd={4},\n .type={0},\n .is_view={5}\n}};\n' + array_init = array_init.format(np_dtype, dummy_array_name, + shape_init, strides_init, len(var.shape), 'false') + array_init += 'stack_array_init(&{})'.format(self._print(var)) + self.add_import(c_imports['ndarrays']) + return buffer_array, array_init + + def _print_Allocate(self, expr): free_code = '' #free the array if its already allocated and checking if its not null if the status is unknown @@ -317,7 +463,7 @@ def _print_Allocate(self, expr): memory_location = 'allocateMemoryOn' + str(memory_location).capitalize() else: memory_location = 'managedMemory' - alloc_code = f"{expr.variable} = \ + alloc_code = f"{self._print(expr.variable)} = \ cuda_array_create({len(expr.shape)}, {tmp_shape}, {dtype}, {is_view}, {memory_location});" return f"{free_code}\n{shape_Assign}\n{alloc_code}\n" @@ -331,6 +477,52 @@ def _print_Deallocate(self, expr): else: return f"cuda_free({var_code});\n" + def _print_IndexedElement(self, expr): + base = expr.base + inds = list(expr.indices) + base_shape = base.shape + allow_negative_indexes = True if isinstance(base, PythonTuple) else base.allows_negative_indexes + for i, ind in enumerate(inds): + if isinstance(ind, PyccelUnarySub) and isinstance(ind.args[0], LiteralInteger): + inds[i] = PyccelMinus(base_shape[i], ind.args[0], simplify = True) + else: + #indices of indexedElement of len==1 shouldn't be a tuple + if isinstance(ind, tuple) and len(ind) == 1: + inds[i].args = ind[0] + if allow_negative_indexes and \ + not isinstance(ind, LiteralInteger) and not isinstance(ind, Slice): + inds[i] = IfTernaryOperator(PyccelLt(ind, LiteralInteger(0)), + PyccelAdd(base_shape[i], ind, simplify = True), ind) + #set dtype to the C struct types + dtype = self._print(expr.dtype) + dtype = self.find_in_ndarray_type_registry(dtype, expr.precision) + base_name = self._print(base) + if getattr(base, 'is_ndarray', False) or isinstance(base, HomogeneousTupleVariable): + if expr.rank > 0: + #managing the Slice input + for i , ind in enumerate(inds): + if isinstance(ind, Slice): + inds[i] = self._new_slice_with_processed_arguments(ind, PyccelArraySize(base, i), + allow_negative_indexes) + else: + inds[i] = Slice(ind, PyccelAdd(ind, LiteralInteger(1), simplify = True), LiteralInteger(1), + Slice.Element) + inds = [self._print(i) for i in inds] + return "cuda_array_slicing(%s, %s, (t_slice []){%s})" % (base_name, expr.rank, ", ".join(inds)) + inds = [self._cast_to(i, NativeInteger(), 8).format(self._print(i)) for i in inds] + else: + raise NotImplementedError(expr) + return "GET_ELEMENT(%s, %s, %s)" % (base_name, dtype, ", ".join(inds)) + + + def _print_Slice(self, expr): + start = self._print(expr.start) + stop = self._print(expr.stop) + step = self._print(expr.step) + slice_type = 'RANGE' if expr.slice_type == Slice.Range else 'ELEMENT' + return f'cuda_new_slice({start}, {stop}, {step}, {slice_type})' + + def _print_KernelCall(self, expr): func = expr.funcdef if func.is_inline: @@ -358,7 +550,7 @@ def _print_KernelCall(self, expr): args = ', '.join(['{}'.format(self._print(a)) for a in args]) # TODO: need to raise error in semantic if we have result , kernel can't return if not func.results: - return '{}<<<{},{}>>>({});\n'.format(func.name, expr.numBlocks, expr.tpblock,args) + return '{}<<>>({});\n'.format(func.name, expr.numBlocks, expr.tpblock,args) def _print_Assign(self, expr): prefix_code = '' @@ -382,7 +574,8 @@ def _print_Assign(self, expr): self._temporary_args = [ObjectAddress(a) for a in lhs] return prefix_code+'{};\n'.format(self._print(rhs)) # Inhomogenous tuples are unravelled and therefore do not exist in the c printer - + if isinstance(rhs, CudaTime): + return prefix_code+self.cuda_time(expr) if isinstance(rhs, (CupyFull)): return prefix_code+self.cuda_arrayFill(expr) if isinstance(rhs, CupyArange): @@ -397,6 +590,8 @@ def _print_Assign(self, expr): return prefix_code+self.fill_NumpyArange(rhs, lhs) if isinstance(rhs, CudaCopy): return prefix_code+self.cudaCopy(lhs, rhs) + if isinstance(rhs, CudaSharedArray): + return '\n' lhs = self._print(expr.lhs) rhs = self._print(expr.rhs) return prefix_code+'{} = {};\n'.format(lhs, rhs) @@ -422,9 +617,9 @@ def arrayFill(self, expr): if rhs.fill_value is not None: if isinstance(rhs.fill_value, Literal): - code_init += 'cuda_array_fill_{0}(({1}){2}, {3});\n'.format(dtype, declare_dtype, self._print(rhs.fill_value), self._print(lhs)) + code_init += 'array_fill_{0}(({1}){2}, {3});\n'.format(dtype, declare_dtype, self._print(rhs.fill_value), self._print(lhs)) else: - code_init += 'cuda_array_fill_{0}({1}, {2});\n'.format(dtype, self._print(rhs.fill_value), self._print(lhs)) + code_init += 'array_fill_{0}({1}, {2});\n'.format(dtype, self._print(rhs.fill_value), self._print(lhs)) return code_init def cuda_Arange(self, expr): @@ -514,14 +709,59 @@ def copy_CudaArray_Data(self, expr): return '%s%s\n' % (dummy_array, cpy_data) def _print_CudaSynchronize(self, expr): - return 'cudaDeviceSynchronize()' + return 'cudaDeviceSynchronize();\n' + + def _print_CudaSyncthreads(self, expr): + return '__syncthreads();\n' + + def _print_CudaSharedArray(self, expr): + return 'TODO' def _print_CudaInternalVar(self, expr): var_name = type(expr).__name__ var_name = cuda_Internal_Var[var_name] dim_c = ('x', 'y', 'z')[expr.dim] return '{}.{}'.format(var_name, dim_c) - + + def _print_TimeVal(self, expr): + self.add_import(c_imports['sys/time']) + return 'timeval' + + def cuda_time(self, expr): + self.add_import(c_imports['sys/time']) + get_time = f'gettimeofday(&{self._print(expr.lhs)}, NULL);\n' + return get_time + + def _print_CudaTimeDiff(self, expr): + self.add_import(c_imports['sys/time']) + start = self._print(expr.start) + end = self._print(expr.end) + return f'({end}.tv_sec - {start}.tv_sec) + ({end}.tv_usec - {start}.tv_usec) / 1e6;' + # return f'((({self._print(expr.end)} - {self._print(expr.start)}) * 1000) / CLOCKS_PER_SEC)' + + def find_in_dtype_registry(self, dtype, prec): + if prec == -1: + prec = default_precision[dtype] + try : + return dtype_registry[(dtype, prec)] + except KeyError: + errors.report(PYCCEL_RESTRICTION_TODO, + symbol = "{}[kind = {}]".format(dtype, prec), + severity='fatal') + + def get_print_format_and_arg(self, var): + try: + arg_format = type_to_format[(self._print(var.dtype), get_final_precision(var))] + except KeyError: + errors.report("{} type is not supported currently".format(var.dtype), severity='fatal') + if var.dtype is NativeComplex(): + arg = '{}, {}'.format(self._print(NumpyReal(var)), self._print(NumpyImag(var))) + elif var.dtype is NativeBool(): + arg = '{} ? "True" : "False"'.format(self._print(var)) + else: + arg = self._print(var) + return arg_format, arg + def cudaCopy(self, lhs, rhs): from_location = 'Host' to_location = 'Host' @@ -530,10 +770,11 @@ def cudaCopy(self, lhs, rhs): if rhs.memory_location in ('device', 'managed'): to_location = 'Device' transfer_type = 'cudaMemcpy{0}To{1}'.format(from_location, to_location) + var = self._print(lhs) if isinstance(rhs.is_async, LiteralTrue): - cpy_data = "cudaMemcpyAsync({0}.raw_data, {1}.raw_data, {0}.buffer_size, {2}, 0);".format(lhs, rhs.arg, transfer_type) + cpy_data = "cudaMemcpyAsync({0}.raw_data, {1}.raw_data, {0}.buffer_size, {2}, 0);".format(var, rhs.arg, transfer_type) else: - cpy_data = "cudaMemcpy({0}.raw_data, {1}.raw_data, {0}.buffer_size, {2});".format(lhs, rhs.arg, transfer_type) + cpy_data = "cudaMemcpy({0}.raw_data, {1}.raw_data, {0}.buffer_size, {2});".format(var, rhs.arg, transfer_type) return '%s\n' % (cpy_data) def ccudacode(expr, filename, assign_to=None, **settings): diff --git a/pyccel/decorators.py b/pyccel/decorators.py index 2d263c4d9e..9642d46a3c 100644 --- a/pyccel/decorators.py +++ b/pyccel/decorators.py @@ -22,6 +22,7 @@ 'template', 'types', 'kernel', + 'device', ) def lambdify(f): @@ -121,3 +122,14 @@ def kernel(f): from numpy import array return array([[f]]) +def device(f): + """ + This decorator is used to mark a Python function as a GPU device function. + + Parameters + ---------- + f : Function + The function to be marked as a device. + + """ + return f diff --git a/pyccel/parser/semantic.py b/pyccel/parser/semantic.py index 18818c8ade..3d26331181 100644 --- a/pyccel/parser/semantic.py +++ b/pyccel/parser/semantic.py @@ -94,7 +94,7 @@ from pyccel.ast.numpyext import DtypePrecisionToCastFunction from pyccel.ast.cupyext import CupyNewArray -from pyccel.ast.cudaext import CudaNewArray, CudaThreadIdx, CudaBlockDim, CudaBlockIdx, CudaGridDim +from pyccel.ast.cudaext import CudaNewArray, CudaThreadIdx, CudaBlockDim, CudaBlockIdx, CudaGridDim, CudaSharedArray from pyccel.ast.omp import (OMP_For_Loop, OMP_Simd_Construct, OMP_Distribute_Construct, OMP_TaskLoop_Construct, OMP_Sections_Construct, Omp_End_Clause, @@ -543,19 +543,20 @@ def _infer_type(self, expr, **settings): d_var['memory_handling'] = 'heap' return d_var - elif isinstance(expr, NumpyNewArray): + elif isinstance(expr, CupyNewArray): d_var['datatype' ] = expr.dtype d_var['memory_handling'] = 'heap' if expr.rank > 0 else 'stack' + d_var['memory_location'] = expr.memory_location d_var['shape' ] = expr.shape d_var['rank' ] = expr.rank d_var['order' ] = expr.order d_var['precision' ] = expr.precision - d_var['cls_base' ] = NumpyArrayClass + d_var['cls_base' ] = CudaArrayClass return d_var - elif isinstance(expr, CupyNewArray): + elif isinstance(expr, CudaNewArray): d_var['datatype' ] = expr.dtype - d_var['memory_handling'] = 'heap' if expr.rank > 0 else 'stack' + d_var['memory_handling'] = 'heap' if (expr.rank > 0 and not isinstance(expr, CudaSharedArray)) else 'stack' d_var['memory_location'] = expr.memory_location d_var['shape' ] = expr.shape d_var['rank' ] = expr.rank @@ -564,17 +565,17 @@ def _infer_type(self, expr, **settings): d_var['cls_base' ] = CudaArrayClass return d_var - elif isinstance(expr, CudaNewArray): + elif isinstance(expr, NumpyNewArray): d_var['datatype' ] = expr.dtype d_var['memory_handling'] = 'heap' if expr.rank > 0 else 'stack' - d_var['memory_location'] = expr.memory_location d_var['shape' ] = expr.shape d_var['rank' ] = expr.rank d_var['order' ] = expr.order d_var['precision' ] = expr.precision - d_var['cls_base' ] = CudaArrayClass + d_var['cls_base' ] = NumpyArrayClass return d_var + elif isinstance(expr, NumpyTranspose): var = expr.internal_var @@ -590,7 +591,6 @@ def _infer_type(self, expr, **settings): return d_var elif isinstance(expr, PyccelAstNode): - d_var['datatype' ] = expr.dtype d_var['memory_handling'] = 'heap' if expr.rank > 0 else 'stack' d_var['shape' ] = expr.shape @@ -1015,7 +1015,7 @@ def _handle_kernel(self, expr, func, args, **settings): symbol = expr, severity='fatal') # TODO : type check the NUMBER OF BLOCKS 'numBlocks' and threads per block 'tpblock' - if not isinstance(expr.numBlocks, LiteralInteger): + if not isinstance(expr.numBlocks, (LiteralInteger, PythonTuple)): # expr.numBlocks could be invalid type, or PyccelSymbol if isinstance(expr.numBlocks, PyccelSymbol): numBlocks = self.get_variable(expr.numBlocks) @@ -1027,7 +1027,7 @@ def _handle_kernel(self, expr, func, args, **settings): errors.report(INVALID_KERNEL_CALL_BP_GRID, symbol = expr, severity='error') - if not isinstance(expr.tpblock, LiteralInteger): + if not isinstance(expr.tpblock, (LiteralInteger, PythonTuple)): # expr.tpblock could be invalid type, or PyccelSymbol if isinstance(expr.tpblock, PyccelSymbol): tpblock = self.get_variable(expr.tpblock) @@ -1046,7 +1046,7 @@ def _handle_kernel(self, expr, func, args, **settings): errors.report("Too few arguments passed in function call", symbol = expr, severity='error') - elif isinstance(a.value, Variable) and a.value.on_stack: + elif isinstance(a.value, Variable) and a.value.rank != 0 and a.value.on_stack: errors.report("A variable allocated on the stack can't be passed to a Kernel function", symbol = expr, severity='error') @@ -1289,7 +1289,7 @@ def _assign_lhs_variable(self, lhs, d_var, rhs, new_expressions, is_augassign,ar # ... # We cannot allow the definition of a stack array in a loop - if lhs.is_stack_array and self.scope.is_loop: + if lhs.is_stack_array and lhs.memory_location != 'shared' and self.scope.is_loop: errors.report(STACK_ARRAY_DEFINITION_IN_LOOP, symbol=name, severity='error', bounding_box=(self._current_fst_node.lineno, diff --git a/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.cu b/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.cu index 2563362433..22b17e5cc6 100644 --- a/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.cu +++ b/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.cu @@ -76,6 +76,7 @@ void cuda_array_fill_double(double c, t_ndarray arr) arr.nd_double[i] = c; } +__device__ __host__ void device_memory(void** devPtr, size_t size) { cudaMalloc(devPtr, size); @@ -90,6 +91,44 @@ void host_memory(void** devPtr, size_t size) { cudaMallocHost(devPtr, size); } +// __device__ +// void shared_array_init(t_ndarray *arr) +// { +// switch (arr->type) +// { +// case nd_int8: +// arr->type_size = sizeof(int8_t); +// break; +// case nd_int16: +// arr->type_size = sizeof(int16_t); +// break; +// case nd_int32: +// arr->type_size = sizeof(int32_t); +// break; +// case nd_int64: +// arr->type_size = sizeof(int64_t); +// break; +// case nd_float: +// arr->type_size = sizeof(float); +// break; +// case nd_double: +// arr->type_size = sizeof(double); +// break; +// case nd_bool: +// arr->type_size = sizeof(bool); +// break; +// } +// arr->length = 1; +// for (int32_t i = 0; i < arr->nd; i++) +// arr->length *= arr->shape[i]; +// arr->buffer_size = arr->length * arr->type_size; +// for (int32_t i = 0; i < arr->nd; i++) +// { +// arr->strides[i] = 1; +// for (int32_t j = i + 1; j < arr->nd; j++) +// arr->strides[i] *= arr->shape[j]; +// } +// } t_ndarray cuda_array_create(int32_t nd, int64_t *shape, enum e_types type, bool is_view, enum e_memory_locations location) @@ -125,14 +164,14 @@ t_ndarray cuda_array_create(int32_t nd, int64_t *shape, } arr.is_view = is_view; arr.length = 1; - cudaMallocManaged(&(arr.shape), arr.nd * sizeof(int64_t)); + cudaMallocManaged(&arr.shape, arr.nd * sizeof(int64_t)); for (int32_t i = 0; i < arr.nd; i++) { arr.length *= shape[i]; arr.shape[i] = shape[i]; } arr.buffer_size = arr.length * arr.type_size; - cudaMallocManaged(&(arr.strides), nd * sizeof(int64_t)); + cudaMallocManaged(&arr.strides, nd * sizeof(int64_t)); for (int32_t i = 0; i < arr.nd; i++) { arr.strides[i] = 1; @@ -171,15 +210,15 @@ int32_t cuda_free(t_ndarray arr) return (1); } -__host__ __device__ -int32_t cuda_free_pointer(t_ndarray arr) -{ - if (arr.is_view == false || arr.shape == NULL) - return (0); - cudaFree(arr.shape); - arr.shape = NULL; - cudaFree(arr.strides); - arr.strides = NULL; - return (1); -} +// __host__ __device__ +// int32_t cuda_free_pointer(t_ndarray arr) +// { +// if (arr.is_view == false || arr.shape == NULL) +// return (0); +// cudaFree(arr.shape); +// arr.shape = NULL; +// cudaFree(arr.strides); +// arr.strides = NULL; +// return (1); +// } diff --git a/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.h b/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.h index 8e88ecd998..e1bfa3875f 100644 --- a/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.h +++ b/pyccel/stdlib/cuda_ndarrays/cuda_ndarrays.h @@ -13,15 +13,17 @@ __global__ void cuda_array_arange_double(t_ndarray arr, int start); __global__ -void _cuda_array_fill_int8(int8_t c, t_ndarray arr); +void cuda_array_fill_int8(int8_t c, t_ndarray arr); __global__ -void _cuda_array_fill_int32(int32_t c, t_ndarray arr); +void cuda_array_fill_int32(int32_t c, t_ndarray arr); __global__ -void _cuda_array_fill_int64(int64_t c, t_ndarray arr); +void cuda_array_fill_int64(int64_t c, t_ndarray arr); __global__ -void _cuda_array_fill_double(double c, t_ndarray arr); +void cuda_array_fill_double(double c, t_ndarray arr); t_ndarray cuda_array_create(int32_t nd, int64_t *shape, enum e_types type, bool is_view, enum e_memory_locations location); +__device__ +void shared_array_init(t_ndarray *arr); int32_t cuda_free_array(t_ndarray dump); int32_t cuda_free_host(t_ndarray arr); diff --git a/pyccel/stdlib/cuda_ndarrays/ho_cuda_ndarrays.h b/pyccel/stdlib/cuda_ndarrays/ho_cuda_ndarrays.h new file mode 100644 index 0000000000..392e223bb7 --- /dev/null +++ b/pyccel/stdlib/cuda_ndarrays/ho_cuda_ndarrays.h @@ -0,0 +1,133 @@ +#ifndef HO_CUDA_NDARRAYS_H +# define HO_CUDA_NDARRAYS_H + +#include "../ndarrays/ndarrays.h" +// CUDA runtime + + +__global__ +void cuda_array_arange_int8(t_ndarray arr, int start); +__global__ +void cuda_array_arange_int32(t_ndarray arr, int start); +__global__ +void cuda_array_arange_int64(t_ndarray arr, int start); +__global__ +void cuda_array_arange_double(t_ndarray arr, int start); + +__global__ +void cuda_array_fill_int8(int8_t c, t_ndarray arr); +__global__ +void cuda_array_fill_int32(int32_t c, t_ndarray arr); +__global__ +void cuda_array_fill_int64(int64_t c, t_ndarray arr); +__global__ +void cuda_array_fill_double(double c, t_ndarray arr); + +t_ndarray cuda_array_create(int32_t nd, int64_t *shape, enum e_types type, bool is_view, enum e_memory_locations location); +__device__ +void shared_array_init(t_ndarray *arr); +int32_t cuda_free_array(t_ndarray dump); + +int32_t cuda_free_host(t_ndarray arr); + +#ifdef HO_CUDA_PYCCEL +#include +#include +__device__ inline void shared_array_init(t_ndarray *arr) +{ + switch (arr->type) + { + case nd_int8: + arr->type_size = sizeof(int8_t); + break; + case nd_int16: + arr->type_size = sizeof(int16_t); + break; + case nd_int32: + arr->type_size = sizeof(int32_t); + break; + case nd_int64: + arr->type_size = sizeof(int64_t); + break; + case nd_float: + arr->type_size = sizeof(float); + break; + case nd_double: + arr->type_size = sizeof(double); + break; + case nd_bool: + arr->type_size = sizeof(bool); + break; + } + arr->length = 1; + for (int32_t i = 0; i < arr->nd; i++) + arr->length *= arr->shape[i]; + arr->buffer_size = arr->length * arr->type_size; + for (int32_t i = 0; i < arr->nd; i++) + { + arr->strides[i] = 1; + for (int32_t j = i + 1; j < arr->nd; j++) + arr->strides[i] *= arr->shape[j]; + } +} + +__device__ __host__ inline +t_slice cuda_new_slice(int32_t start, int32_t end, int32_t step, enum e_slice_type type) +{ + t_slice slice; + + slice.start = start; + slice.end = end; + slice.step = step; + slice.type = type; + return (slice); +} + +__device__ __host__ inline +t_ndarray cuda_array_slicing(t_ndarray arr, int n, t_slice slices[]) +{ + t_ndarray view; + t_slice slice; + int32_t start = 0; + int32_t j; + + view.nd = n; + view.type = arr.type; + view.type_size = arr.type_size; + view.shape = (int64_t *)malloc(sizeof(int64_t) * view.nd); + view.strides = (int64_t *)malloc(sizeof(int64_t) * view.nd); + view.is_view = true; + j = 0; + for (int32_t i = 0; i < arr.nd; i++) + { + slice = slices[i]; + if (slice.type == RANGE) + { + view.shape[j] = (slice.end - slice.start + (slice.step - 1)) / slice.step; + view.strides[j] = arr.strides[i] * slice.step; + j++; + } + start += slice.start * arr.strides[i]; + } + + view.raw_data = arr.raw_data + start * arr.type_size; + view.length = 1; + for (int32_t i = 0; i < view.nd; i++) + view.length *= view.shape[i]; + return (view); +} +__host__ __device__ inline +int32_t cuda_free_pointer(t_ndarray arr) +{ + if (arr.is_view == false || arr.shape == NULL) + return (0); + free(arr.shape); + arr.shape = NULL; + free(arr.strides); + arr.strides = NULL; + return (1); +} + +#endif +#undef HO_CUDA_PYCCEL +#endif \ No newline at end of file