Skip to content

Commit 205b012

Browse files
[NVIDIA] Bunch triang functions (#998)
* Updated .gitignore * Added initial support for Acos, Acosh, Asin, Asinh * Added new separate fiel for single_layer_tests for activations * Added missed implementation of operators * Updated docs/cuda_opset.md --------- Co-authored-by: Denis Kotov <[email protected]>
1 parent 1554fd3 commit 205b012

File tree

33 files changed

+874
-7
lines changed

33 files changed

+874
-7
lines changed

modules/nvidia_plugin/.gitignore

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
1+
.idea/
2+
.env
3+
14
__pycache__/
5+
build/
6+
dist/
7+
*.egg-info/
8+
9+
report_api.xml
10+
report_op.xml

modules/nvidia_plugin/build.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ BUILD_TYPE=${BUILD_TYPE:-Release}
1212
BUILD_TARGETS=${BUILD_TARGETS:-"ov_nvidia_func_tests ov_nvidia_unit_tests openvino_nvidia_gpu_plugin benchmark_app"}
1313
WHEEL_VERSION=${WHEEL_VERSION:-"2022.3.0"}
1414
ENABLE_TESTS=${ENABLE_TESTS:-"ON"}
15+
ENABLE_FUNCTIONAL_TESTS=${ENABLE_FUNCTIONAL_TESTS:-"OFF"}
1516

1617
[[ -n "${OPENVINO_HOME}" ]] || { echo "OPENVINO_HOME environment variable is expected"; exit 1; }
1718
[[ -n "${OPENVINO_CONTRIB}" ]] || { echo "OPENVINO_CONTRIB environment variable is expected"; exit 1; }
@@ -41,8 +42,10 @@ cmake "${OPENVINO_HOME}" \
4142
-DENABLE_NVIDIA=ON \
4243
-DENABLE_PLUGINS_XML=ON \
4344
-DENABLE_TESTS="${ENABLE_TESTS}" \
45+
-DENABLE_FUNCTIONAL_TESTS="${ENABLE_FUNCTIONAL_TESTS}" \
4446
-DBUILD_arm_plugin=OFF \
4547
-DBUILD_java_api=OFF \
48+
-DBUILD_llama_cpp_plugin=OFF \
4649
-DOPENVINO_EXTRA_MODULES="${OPENVINO_CONTRIB}/modules" \
4750
-DWHEEL_VERSION="${WHEEL_VERSION}" \
4851
-DVERBOSE_BUILD=ON \

modules/nvidia_plugin/docs/cuda_opset.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ The semantics match corresponding nGraph operation classes declared in `namespac
1010
| Layers | NVIDIA GPU plugin |
1111
|------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
1212
| [Abs](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Abs_1.md) | Supported |
13-
| [Acos](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Acos_1.md) | Not Supported |
14-
| [Acosh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Acosh_3.md) | Not Supported |
13+
| [Acos](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Acos_1.md) | Supported |
14+
| [Acosh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Acosh_3.md) | Supported |
1515
| [Add](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Add_1.md) | Supported* |
16-
| [Asin](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Asin_1.md) | Not Supported |
17-
| [Asinh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Asinh_3.md) | Not Supported |
16+
| [Asin](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Asin_1.md) | Supported |
17+
| [Asinh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Asinh_3.md) | Supported |
1818
| [Assign](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/infrastructure/Assign_3.md) | Not Supported |
19-
| [Atan](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Atan_1.md) | Not Supported |
20-
| [Atanh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Atanh_3.md) | Not Supported |
19+
| [Atan](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Atan_1.md) | Supported |
20+
| [Atanh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Atanh_3.md) | Supported |
2121
| [AvgPool](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/pooling/AvgPool_1.md) | Supported |
2222
| [BatchNormInference](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/normalization/BatchNormInference_5.md) | Not Supported |
2323
| [BatchToSpace](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/BatchToSpace_2.md) | Not Supported |
@@ -148,7 +148,7 @@ The semantics match corresponding nGraph operation classes declared in `namespac
148148
| [StridedSlice](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/StridedSlice_1.md) | Supported |
149149
| [Subtract](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Subtract_1.md) | Supported |
150150
| [Swish](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/activation/Swish_4.md) | Supported |
151-
| [Tan](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Tan_1.md) | Not Supported |
151+
| [Tan](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Tan_1.md) | Supported |
152152
| [Tanh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Tanh_1.md) | Supported |
153153
| [TensorIterator](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/infrastructure/TensorIterator_1.md) | Supported |
154154
| [Tile](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/Tile_1.md) | Not Supported |

modules/nvidia_plugin/src/cuda/math.cuh

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ inline __device__ T abs(T a) {
9696
return static_cast<T>(::fabsf(static_cast<float>(a)));
9797
}
9898

99+
template <typename T>
100+
inline __device__ T tan(T a) {
101+
return static_cast<T>(::tanf(static_cast<float>(a)));
102+
}
103+
99104
template <typename T>
100105
inline __device__ T tanh(T a) {
101106
return static_cast<T>(::tanhf(static_cast<float>(a)));
@@ -116,6 +121,16 @@ inline __device__ T sinh(T a) {
116121
return static_cast<T>(::sinhf(static_cast<float>(a)));
117122
}
118123

124+
template <typename T>
125+
inline __device__ T asin(T a) {
126+
return static_cast<T>(::asinf(static_cast<float>(a)));
127+
}
128+
129+
template <typename T>
130+
inline __device__ T asinh(T a) {
131+
return static_cast<T>(::asinhf(static_cast<float>(a)));
132+
}
133+
119134
template <typename T>
120135
inline __device__ T cos(T a) {
121136
return static_cast<T>(::cosf(static_cast<float>(a)));
@@ -126,6 +141,26 @@ inline __device__ T cosh(T a) {
126141
return static_cast<T>(::coshf(static_cast<float>(a)));
127142
}
128143

144+
template <typename T>
145+
inline __device__ T acos(T a) {
146+
return static_cast<T>(::acosf(static_cast<float>(a)));
147+
}
148+
149+
template <typename T>
150+
inline __device__ T acosh(T a) {
151+
return static_cast<T>(::acoshf(static_cast<float>(a)));
152+
}
153+
154+
template <typename T>
155+
inline __device__ T atan(T a) {
156+
return static_cast<T>(::atanf(static_cast<float>(a)));
157+
}
158+
159+
template <typename T>
160+
inline __device__ T atanh(T a) {
161+
return static_cast<T>(::atanhf(static_cast<float>(a)));
162+
}
163+
129164
template <typename T>
130165
inline __device__ T log(T a) {
131166
return static_cast<T>(::logf(static_cast<float>(a)));
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (C) 2021-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "acos.hpp"
6+
7+
namespace ov {
8+
namespace nvidia_gpu {
9+
namespace kernel {
10+
11+
namespace cumath = CUDA::math;
12+
13+
template <typename T>
14+
struct AcosOpImpl {
15+
__device__ static inline T op(T x) {
16+
return cumath::acos(x);
17+
}
18+
};
19+
20+
Acos::Acos(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
21+
: impl_{element_type, max_threads_per_block, num_elements} {}
22+
23+
void Acos::operator()(cudaStream_t stream, const void* in0, void* out) const {
24+
impl_(stream, in0, out);
25+
}
26+
27+
} // namespace kernel
28+
} // namespace nvidia_gpu
29+
} // namespace ov
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (C) 2021-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "details/cuda_type_traits.hpp"
8+
#include "details/elementwise_unary.cuh"
9+
10+
namespace ov {
11+
namespace nvidia_gpu {
12+
namespace kernel {
13+
14+
template <typename T>
15+
struct AcosOpImpl;
16+
/**
17+
* Elementwise Acos operation
18+
*/
19+
class Acos {
20+
public:
21+
Acos(Type_t element_type, size_t max_threads_per_block, size_t num_elements);
22+
23+
void operator()(cudaStream_t stream, const void* in0, void* out) const;
24+
25+
private:
26+
ElementwiseUnary<AllElementTypesSwitch, AcosOpImpl> impl_;
27+
};
28+
29+
} // namespace kernel
30+
} // namespace nvidia_gpu
31+
} // namespace ov
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (C) 2021-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "acosh.hpp"
6+
7+
namespace ov {
8+
namespace nvidia_gpu {
9+
namespace kernel {
10+
11+
namespace cumath = CUDA::math;
12+
13+
template <typename T>
14+
struct AcoshOpImpl {
15+
__device__ static inline T op(T x) {
16+
return cumath::acosh(x);
17+
}
18+
};
19+
20+
Acosh::Acosh(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
21+
: impl_{element_type, max_threads_per_block, num_elements} {}
22+
23+
void Acosh::operator()(cudaStream_t stream, const void* in0, void* out) const {
24+
impl_(stream, in0, out);
25+
}
26+
27+
} // namespace kernel
28+
} // namespace nvidia_gpu
29+
} // namespace ov
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (C) 2021-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "details/cuda_type_traits.hpp"
8+
#include "details/elementwise_unary.cuh"
9+
10+
namespace ov {
11+
namespace nvidia_gpu {
12+
namespace kernel {
13+
14+
template <typename T>
15+
struct AcoshOpImpl;
16+
/**
17+
* Elementwise Acosh operation
18+
*/
19+
class Acosh {
20+
public:
21+
Acosh(Type_t element_type, size_t max_threads_per_block, size_t num_elements);
22+
23+
void operator()(cudaStream_t stream, const void* in0, void* out) const;
24+
25+
private:
26+
ElementwiseUnary<AllElementTypesSwitch, AcoshOpImpl> impl_;
27+
};
28+
29+
} // namespace kernel
30+
} // namespace nvidia_gpu
31+
} // namespace ov
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (C) 2021-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "asin.hpp"
6+
7+
namespace ov {
8+
namespace nvidia_gpu {
9+
namespace kernel {
10+
11+
namespace cumath = CUDA::math;
12+
13+
template <typename T>
14+
struct AsinOpImpl {
15+
__device__ static inline T op(T x) {
16+
return cumath::asin(x);
17+
}
18+
};
19+
20+
Asin::Asin(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
21+
: impl_{element_type, max_threads_per_block, num_elements} {}
22+
23+
void Asin::operator()(cudaStream_t stream, const void* in0, void* out) const {
24+
impl_(stream, in0, out);
25+
}
26+
27+
} // namespace kernel
28+
} // namespace nvidia_gpu
29+
} // namespace ov
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (C) 2021-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "details/cuda_type_traits.hpp"
8+
#include "details/elementwise_unary.cuh"
9+
10+
namespace ov {
11+
namespace nvidia_gpu {
12+
namespace kernel {
13+
14+
template <typename T>
15+
struct AsinOpImpl;
16+
/**
17+
* Elementwise asin
18+
*/
19+
class Asin {
20+
public:
21+
Asin(Type_t element_type, size_t max_threads_per_block, size_t num_elements);
22+
23+
void operator()(cudaStream_t stream, const void* in0, void* out) const;
24+
25+
private:
26+
ElementwiseUnary<AllElementTypesSwitch, AsinOpImpl> impl_;
27+
};
28+
29+
} // namespace kernel
30+
} // namespace nvidia_gpu
31+
} // namespace ov

0 commit comments

Comments
 (0)