|
3 | 3 | # SPDX-License-Identifier: Apache-2.0
|
4 | 4 |
|
5 | 5 | from llvmlite import ir as llvmir
|
| 6 | +from llvmlite.ir import Constant |
6 | 7 | from llvmlite.ir.types import DoubleType, FloatType
|
7 | 8 | from numba import types
|
| 9 | +from numba.core import cgutils |
| 10 | +from numba.core import config as numba_config |
8 | 11 | from numba.core.typing import signature
|
9 |
| -from numba.extending import intrinsic |
| 12 | +from numba.extending import intrinsic, overload_classmethod |
10 | 13 | from numba.np.arrayobj import (
|
11 |
| - _empty_nd_impl, |
12 | 14 | _parse_empty_args,
|
13 | 15 | _parse_empty_like_args,
|
14 | 16 | get_itemsize,
|
| 17 | + make_array, |
| 18 | + populate_array, |
15 | 19 | )
|
16 | 20 |
|
17 | 21 | from numba_dpex.core.runtime import context as dpexrt
|
| 22 | +from numba_dpex.core.types import DpnpNdArray |
| 23 | + |
| 24 | + |
| 25 | +def _empty_nd_impl(context, builder, arrtype, shapes): |
| 26 | + """Utility function used for allocating a new array during LLVM code |
| 27 | + generation (lowering). Given a target context, builder, array |
| 28 | + type, and a tuple or list of lowered dimension sizes, returns a |
| 29 | + LLVM value pointing at a Numba runtime allocated array. |
| 30 | + """ |
| 31 | + |
| 32 | + arycls = make_array(arrtype) |
| 33 | + ary = arycls(context, builder) |
| 34 | + |
| 35 | + datatype = context.get_data_type(arrtype.dtype) |
| 36 | + itemsize = context.get_constant(types.intp, get_itemsize(context, arrtype)) |
| 37 | + |
| 38 | + # compute array length |
| 39 | + arrlen = context.get_constant(types.intp, 1) |
| 40 | + overflow = Constant(llvmir.IntType(1), 0) |
| 41 | + for s in shapes: |
| 42 | + arrlen_mult = builder.smul_with_overflow(arrlen, s) |
| 43 | + arrlen = builder.extract_value(arrlen_mult, 0) |
| 44 | + overflow = builder.or_(overflow, builder.extract_value(arrlen_mult, 1)) |
| 45 | + |
| 46 | + if arrtype.ndim == 0: |
| 47 | + strides = () |
| 48 | + elif arrtype.layout == "C": |
| 49 | + strides = [itemsize] |
| 50 | + for dimension_size in reversed(shapes[1:]): |
| 51 | + strides.append(builder.mul(strides[-1], dimension_size)) |
| 52 | + strides = tuple(reversed(strides)) |
| 53 | + elif arrtype.layout == "F": |
| 54 | + strides = [itemsize] |
| 55 | + for dimension_size in shapes[:-1]: |
| 56 | + strides.append(builder.mul(strides[-1], dimension_size)) |
| 57 | + strides = tuple(strides) |
| 58 | + else: |
| 59 | + raise NotImplementedError( |
| 60 | + "Don't know how to allocate array with layout '{0}'.".format( |
| 61 | + arrtype.layout |
| 62 | + ) |
| 63 | + ) |
| 64 | + |
| 65 | + # Check overflow, numpy also does this after checking order |
| 66 | + allocsize_mult = builder.smul_with_overflow(arrlen, itemsize) |
| 67 | + allocsize = builder.extract_value(allocsize_mult, 0) |
| 68 | + overflow = builder.or_(overflow, builder.extract_value(allocsize_mult, 1)) |
| 69 | + |
| 70 | + with builder.if_then(overflow, likely=False): |
| 71 | + # Raise same error as numpy, see: |
| 72 | + # https://github.com/numpy/numpy/blob/2a488fe76a0f732dc418d03b452caace161673da/numpy/core/src/multiarray/ctors.c#L1095-L1101 # noqa: E501 |
| 73 | + context.call_conv.return_user_exc( |
| 74 | + builder, |
| 75 | + ValueError, |
| 76 | + ( |
| 77 | + "array is too big; `arr.size * arr.dtype.itemsize` is larger than" |
| 78 | + " the maximum possible size.", |
| 79 | + ), |
| 80 | + ) |
| 81 | + |
| 82 | + usm_ty = arrtype.usm_type |
| 83 | + usm_ty_val = 0 |
| 84 | + if usm_ty == "device": |
| 85 | + usm_ty_val = 1 |
| 86 | + elif usm_ty == "shared": |
| 87 | + usm_ty_val = 2 |
| 88 | + elif usm_ty == "host": |
| 89 | + usm_ty_val = 3 |
| 90 | + usm_type = context.get_constant(types.uint64, usm_ty_val) |
| 91 | + device = context.insert_const_string(builder.module, arrtype.device) |
| 92 | + |
| 93 | + args = ( |
| 94 | + context.get_dummy_value(), |
| 95 | + allocsize, |
| 96 | + usm_type, |
| 97 | + device, |
| 98 | + ) |
| 99 | + mip = types.MemInfoPointer(types.voidptr) |
| 100 | + arytypeclass = types.TypeRef(type(arrtype)) |
| 101 | + sig = signature( |
| 102 | + mip, |
| 103 | + arytypeclass, |
| 104 | + types.intp, |
| 105 | + types.uint64, |
| 106 | + types.voidptr, |
| 107 | + ) |
| 108 | + from numba_dpex.decorators import dpjit |
| 109 | + |
| 110 | + op = dpjit(_call_usm_allocator) |
| 111 | + fnop = context.typing_context.resolve_value_type(op) |
| 112 | + # The _call_usm_allocator function will be compiled and added to registry |
| 113 | + # when the get_call_type function is invoked. |
| 114 | + fnop.get_call_type(context.typing_context, sig.args, {}) |
| 115 | + eqfn = context.get_function(fnop, sig) |
| 116 | + meminfo = eqfn(builder, args) |
| 117 | + |
| 118 | + data = context.nrt.meminfo_data(builder, meminfo) |
| 119 | + |
| 120 | + intp_t = context.get_value_type(types.intp) |
| 121 | + shape_array = cgutils.pack_array(builder, shapes, ty=intp_t) |
| 122 | + strides_array = cgutils.pack_array(builder, strides, ty=intp_t) |
| 123 | + |
| 124 | + populate_array( |
| 125 | + ary, |
| 126 | + data=builder.bitcast(data, datatype.as_pointer()), |
| 127 | + shape=shape_array, |
| 128 | + strides=strides_array, |
| 129 | + itemsize=itemsize, |
| 130 | + meminfo=meminfo, |
| 131 | + ) |
| 132 | + |
| 133 | + return ary |
| 134 | + |
| 135 | + |
| 136 | +numba_config.DISABLE_PERFORMANCE_WARNINGS = 0 |
| 137 | + |
| 138 | + |
| 139 | +def _call_usm_allocator(arrtype, size, usm_type, device): |
| 140 | + """Trampoline to call the intrinsic used for allocation""" |
| 141 | + return arrtype._usm_allocate(size, usm_type, device) |
| 142 | + |
| 143 | + |
| 144 | +numba_config.DISABLE_PERFORMANCE_WARNINGS = 1 |
| 145 | + |
| 146 | + |
| 147 | +@overload_classmethod(DpnpNdArray, "_usm_allocate", target="dpex") |
| 148 | +def _ol_array_allocate(cls, allocsize, usm_type, device): |
| 149 | + """Implements an allocator for dpnp.ndarrays.""" |
| 150 | + |
| 151 | + def impl(cls, allocsize, usm_type, device): |
| 152 | + return intrin_usm_alloc(allocsize, usm_type, device) |
| 153 | + |
| 154 | + return impl |
18 | 155 |
|
19 | 156 |
|
20 | 157 | def alloc_empty_arrayobj(context, builder, sig, args, is_like=False):
|
|
0 commit comments