Skip to content

Commit 62b015a

Browse files
Merge master into not_equal_impl
2 parents 4a5a84a + 1596a13 commit 62b015a

32 files changed

+5447
-72
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ include dpctl/program/_program_api.h
1818
include dpctl/tensor/_usmarray.h
1919
include dpctl/tensor/_usmarray_api.h
2020
recursive-include dpctl/tensor/include *
21+
recursive-include dpctl/tensor/libtensor/include *
2122
include dpctl/tests/input_files/*
2223
include dpctl/tests/*.pyx

cmake/FindDpctl.cmake

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ if(NOT Dpctl_FOUND)
4040
OUTPUT_STRIP_TRAILING_WHITESPACE
4141
ERROR_QUIET
4242
)
43-
4443
endif()
4544
endif()
4645

@@ -49,6 +48,12 @@ find_path(Dpctl_INCLUDE_DIR
4948
PATHS "${_dpctl_include_dir}" "${PYTHON_INCLUDE_DIR}"
5049
PATH_SUFFIXES dpctl/include
5150
)
51+
get_filename_component(_dpctl_dir ${_dpctl_include_dir} DIRECTORY)
52+
53+
find_path(Dpctl_TENSOR_INCLUDE_DIR
54+
kernels utils
55+
PATHS "${_dpctl_dir}/tensor/libtensor/include"
56+
)
5257

5358
set(Dpctl_INCLUDE_DIRS ${Dpctl_INCLUDE_DIR})
5459

@@ -57,8 +62,9 @@ set(Dpctl_INCLUDE_DIRS ${Dpctl_INCLUDE_DIR})
5762
include(FindPackageHandleStandardArgs)
5863
find_package_handle_standard_args(Dpctl
5964
REQUIRED_VARS
60-
Dpctl_INCLUDE_DIR
65+
Dpctl_INCLUDE_DIR Dpctl_TENSOR_INCLUDE_DIR
6166
VERSION_VAR Dpctl_VERSION
6267
)
6368

6469
mark_as_advanced(Dpctl_INCLUDE_DIR)
70+
mark_as_advanced(Dpctl_TENSOR_INCLUDE_DIR)

dpctl/__main__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def print_includes() -> None:
3535
print("-I " + dpctl.get_include())
3636

3737

38+
def print_tensor_includes() -> None:
39+
"Prints include flags for dpctl and SyclInterface library"
40+
dpctl_dir = _dpctl_dir()
41+
libtensor_dir = os.path.join(dpctl_dir, "tensor", "libtensor", "include")
42+
print("-I " + libtensor_dir)
43+
44+
3845
def print_cmake_dir() -> None:
3946
"Prints directory with FindDpctl.cmake"
4047
dpctl_dir = _dpctl_dir()
@@ -75,7 +82,12 @@ def main() -> None:
7582
parser.add_argument(
7683
"--includes",
7784
action="store_true",
78-
help="Include flags dpctl headers.",
85+
help="Include flags for dpctl headers.",
86+
)
87+
parser.add_argument(
88+
"--tensor-includes",
89+
action="store_true",
90+
help="Include flags for dpctl libtensor headers.",
7991
)
8092
parser.add_argument(
8193
"--cmakedir",
@@ -128,6 +140,8 @@ def main() -> None:
128140
return
129141
if args.includes:
130142
print_includes()
143+
if args.tensor_includes:
144+
print_tensor_includes()
131145
if args.cmakedir:
132146
print_cmake_dir()
133147
if args.library:

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pybind11_add_module(${python_module_name} MODULE
4747
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
4848
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
50+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
5051
)
5152
set(_clang_prefix "")
5253
if (WIN32)

dpctl/tensor/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,27 @@
9494
from ._elementwise_funcs import (
9595
abs,
9696
add,
97+
conj,
9798
cos,
9899
divide,
99100
equal,
101+
exp,
102+
expm1,
103+
imag,
100104
isfinite,
101105
isinf,
102106
isnan,
107+
log,
108+
log1p,
103109
multiply,
104110
not_equal,
111+
proj,
112+
real,
113+
sin,
105114
sqrt,
106115
subtract,
107116
)
117+
from ._reduction import sum
108118

109119
__all__ = [
110120
"Device",
@@ -183,14 +193,24 @@
183193
"inf",
184194
"abs",
185195
"add",
196+
"conj",
186197
"cos",
198+
"exp",
199+
"expm1",
200+
"imag",
187201
"isinf",
188202
"isnan",
189203
"isfinite",
204+
"log",
205+
"log1p",
206+
"proj",
207+
"real",
208+
"sin",
190209
"sqrt",
191210
"divide",
192211
"multiply",
193212
"subtract",
194213
"equal",
195214
"not_equal",
215+
"sum",
196216
]

dpctl/tensor/_elementwise_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
255255
raise ValueError
256256
o1_kind_num = _weak_type_num_kind(o1_dtype)
257257
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
258-
if o1_kind_num > o2_kind_num:
258+
if o1_kind_num > o2_kind_num or o1_kind_num == 2:
259259
if isinstance(o1_dtype, WeakBooleanType):
260260
return dpt.bool, o2_dtype
261261
if isinstance(o1_dtype, WeakIntegralType):
@@ -273,7 +273,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
273273
):
274274
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
275275
o2_kind_num = _weak_type_num_kind(o2_dtype)
276-
if o2_kind_num > o1_kind_num:
276+
if o2_kind_num > o1_kind_num or o2_kind_num == 2:
277277
if isinstance(o2_dtype, WeakBooleanType):
278278
return o1_dtype, dpt.bool
279279
if isinstance(o2_dtype, WeakIntegralType):

0 commit comments

Comments
 (0)