Skip to content

Commit 777d1a4

Browse files
chenwhqlzhwesky2010JiabinYangShixiaowei02
authored
[Cherry-pick] The 4th part of new custom op (#31282)
* modify custom op dependent from paddle_framework to paddle_custom_op (#31195) * [Custom Op] Remove unsupport dtypes (#31232) * remove remove_unsupport_dtype * remove remove_unsupport_dtype * remove test dtype * add more include * change dtype.h's enum as enum class to avoid conflict with inference lib * make enum as enum class * remove additional test * merge develop * polish code * [Custom OP] Support stream set on Custom Op (#31257) * [Custom OP] change the user header file format, test=develop (#31274) * [Custom OP]add PD_THROW and PD_CHECK for User Error message (#31253) * [Custom OP]add PD_THROW and PD_CHECK for User error message * PD_THROW and PD_CHECK, fix comment * fix Windows error message * fix Windows error message * fix CI * [Custom OP]add MSVC compile check on Windows (#31265) * fix test_check_abi Co-authored-by: Zhou Wei <[email protected]> Co-authored-by: Jiabin Yang <[email protected]> Co-authored-by: 石晓伟 <[email protected]> Co-authored-by: zhouwei25 <[email protected]>
1 parent f4a69d5 commit 777d1a4

32 files changed

+645
-536
lines changed

cmake/inference_lib.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ copy(inference_lib_dist
189189
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/crypto/)
190190
include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
191191

192+
copy(inference_lib_dist
193+
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
194+
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
195+
192196
# CAPI inference library for only inference
193197
set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
194198
"A path setting CAPI paddle inference shared")

paddle/extension.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ limitations under the License. */
1515
#pragma once
1616

1717
// All paddle apis in C++ frontend
18-
#include "paddle/fluid/extension/include/all.h"
18+
#include "paddle/fluid/extension/include/ext_all.h"

paddle/fluid/extension/include/dispatch.h

Lines changed: 0 additions & 168 deletions
This file was deleted.

paddle/fluid/extension/include/all.h renamed to paddle/fluid/extension/include/ext_all.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ limitations under the License. */
2424
#endif
2525
#endif
2626

27-
#include "paddle/fluid/extension/include/dispatch.h"
28-
#include "paddle/fluid/extension/include/dtype.h"
29-
#include "paddle/fluid/extension/include/op_meta_info.h"
30-
#include "paddle/fluid/extension/include/place.h"
31-
#include "paddle/fluid/extension/include/tensor.h"
27+
#include "ext_dispatch.h" // NOLINT
28+
#include "ext_dtype.h" // NOLINT
29+
#include "ext_exception.h" // NOLINT
30+
#include "ext_op_meta_info.h" // NOLINT
31+
#include "ext_place.h" // NOLINT
32+
#include "ext_tensor.h" // NOLINT
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "ext_dtype.h" // NOLINT
18+
#include "ext_exception.h" // NOLINT
19+
20+
namespace paddle {
21+
22+
///////// Basic Marco ///////////
23+
24+
#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \
25+
case enum_type: { \
26+
using HINT = type; \
27+
__VA_ARGS__(); \
28+
break; \
29+
}
30+
31+
#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
32+
PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)
33+
34+
///////// Floating Dispatch Marco ///////////
35+
36+
#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
37+
[&] { \
38+
const auto& __dtype__ = TYPE; \
39+
switch (__dtype__) { \
40+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
41+
__VA_ARGS__) \
42+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
43+
__VA_ARGS__) \
44+
default: \
45+
PD_THROW("function " #NAME " is not implemented for data type `", \
46+
::paddle::ToString(__dtype__), "`"); \
47+
} \
48+
}()
49+
50+
///////// Integral Dispatch Marco ///////////
51+
52+
#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
53+
[&] { \
54+
const auto& __dtype__ = TYPE; \
55+
switch (__dtype__) { \
56+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
57+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
58+
__VA_ARGS__) \
59+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
60+
__VA_ARGS__) \
61+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
62+
__VA_ARGS__) \
63+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
64+
__VA_ARGS__) \
65+
default: \
66+
PD_THROW("function " #NAME " is not implemented for data type `" + \
67+
::paddle::ToString(__dtype__) + "`"); \
68+
} \
69+
}()
70+
71+
///////// Floating and Integral Dispatch Marco ///////////
72+
73+
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
74+
[&] { \
75+
const auto& __dtype__ = TYPE; \
76+
switch (__dtype__) { \
77+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
78+
__VA_ARGS__) \
79+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
80+
__VA_ARGS__) \
81+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
82+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
83+
__VA_ARGS__) \
84+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
85+
__VA_ARGS__) \
86+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
87+
__VA_ARGS__) \
88+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
89+
__VA_ARGS__) \
90+
default: \
91+
PD_THROW("function " #NAME " is not implemented for data type `" + \
92+
::paddle::ToString(__dtype__) + "`"); \
93+
} \
94+
}()
95+
96+
// TODO(chenweihang): Add more Marcos in the future if needed
97+
98+
} // namespace paddle

paddle/fluid/extension/include/dtype.h renamed to paddle/fluid/extension/include/ext_dtype.h

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,24 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
1514
#pragma once
1615

17-
#include "paddle/fluid/platform/bfloat16.h"
18-
#include "paddle/fluid/platform/complex128.h"
19-
#include "paddle/fluid/platform/complex64.h"
20-
#include "paddle/fluid/platform/float16.h"
16+
#include <cstdint>
17+
#include <string>
2118

22-
namespace paddle {
19+
#include "ext_exception.h" // NOLINT
2320

24-
using float16 = paddle::platform::float16;
25-
using bfloat16 = paddle::platform::bfloat16;
26-
using complex64 = paddle::platform::complex64;
27-
using complex128 = paddle::platform::complex128;
21+
namespace paddle {
2822

29-
enum DataType {
23+
enum class DataType {
3024
BOOL,
3125
INT8,
3226
UINT8,
3327
INT16,
3428
INT32,
3529
INT64,
36-
FLOAT16,
37-
BFLOAT16,
3830
FLOAT32,
3931
FLOAT64,
40-
COMPLEX64,
41-
COMPLEX128,
4232
// TODO(JiabinYang) support more data types if needed.
4333
};
4434

@@ -56,36 +46,24 @@ inline std::string ToString(DataType dtype) {
5646
return "int32_t";
5747
case DataType::INT64:
5848
return "int64_t";
59-
case DataType::FLOAT16:
60-
return "float16";
61-
case DataType::BFLOAT16:
62-
return "bfloat16";
6349
case DataType::FLOAT32:
6450
return "float";
6551
case DataType::FLOAT64:
6652
return "double";
67-
case DataType::COMPLEX64:
68-
return "complex64";
69-
case DataType::COMPLEX128:
70-
return "complex128";
7153
default:
72-
throw std::runtime_error("Unsupported paddle enum data type.");
54+
PD_THROW("Unsupported paddle enum data type.");
7355
}
7456
}
7557

76-
#define PD_FOR_EACH_DATA_TYPE(_) \
77-
_(bool, DataType::BOOL) \
78-
_(int8_t, DataType::INT8) \
79-
_(uint8_t, DataType::UINT8) \
80-
_(int16_t, DataType::INT16) \
81-
_(int, DataType::INT32) \
82-
_(int64_t, DataType::INT64) \
83-
_(float16, DataType::FLOAT16) \
84-
_(bfloat16, DataType::BFLOAT16) \
85-
_(float, DataType::FLOAT32) \
86-
_(double, DataType::FLOAT64) \
87-
_(complex64, DataType::COMPLEX64) \
88-
_(complex128, DataType::COMPLEX128)
58+
#define PD_FOR_EACH_DATA_TYPE(_) \
59+
_(bool, DataType::BOOL) \
60+
_(int8_t, DataType::INT8) \
61+
_(uint8_t, DataType::UINT8) \
62+
_(int16_t, DataType::INT16) \
63+
_(int, DataType::INT32) \
64+
_(int64_t, DataType::INT64) \
65+
_(float, DataType::FLOAT32) \
66+
_(double, DataType::FLOAT64)
8967

9068
template <paddle::DataType T>
9169
struct DataTypeToCPPType;

0 commit comments

Comments
 (0)