@@ -30,29 +30,24 @@ void FullKernel(const Context& dev_ctx,
30
30
31
31
class BinaryOperator : public HpuOperator {
32
32
public:
33
- BinaryOperator (std::string guid_prefix,
34
- std::string node_name,
35
- bool in_place = false )
36
- : HpuOperator(guid_prefix), pName_(node_name) {
37
- inPlace_ = in_place;
38
- }
33
+ explicit BinaryOperator (std::string guid_prefix) : HpuOperator(guid_prefix) {}
39
34
40
35
void AddNode (const std::vector<DIMS>& ins,
41
36
const std::vector<DIMS>& outs,
42
- synDataType datatype) {
37
+ synDataType datatype,
38
+ bool in_place = false ) {
43
39
assert (ins.size () == 2 && " input size should be 2" );
44
40
assert (outs.size () == 1 && " output size should be 1" );
45
41
46
- synSectionHandle section = nullptr ;
47
- if (inPlace_) {
48
- section = createSection ();
49
- }
42
+ synSectionHandle section = in_place ? createSection () : nullptr ;
50
43
51
44
synTensor inputs[ins.size ()] = {
52
45
createTensor (ins[0 ].size (), datatype, ins[0 ], true , " x" , section),
53
46
createTensor (ins[1 ].size (), datatype, ins[1 ], true , " y" )};
54
47
synTensor outputs[outs.size ()] = {createTensor (
55
48
outs[0 ].size (), datatype, outs[0 ], true , " output" , section)};
49
+
50
+ guid_ = guid_ + SynDataTypeToStr (datatype);
56
51
synStatus status = synNodeCreate (graphHandle_,
57
52
inputs,
58
53
outputs,
@@ -61,62 +56,61 @@ class BinaryOperator : public HpuOperator {
61
56
nullptr ,
62
57
0 ,
63
58
guid_.c_str (),
64
- pName_. c_str () ,
59
+ "binary",
65
60
nullptr ,
66
61
nullptr );
67
62
PD_CHECK (status == synSuccess,
68
63
" [RUNTIME] synNodeCreate binary fwd () failed = %d" ,
69
64
status);
70
65
}
71
- std::string pName_;
72
- bool inPlace_;
73
66
};
74
67
75
// Instantiates the raw (axis-aware signature) binary elementwise kernel for
// one op.
//
// kernel_func: PascalCase kernel stem, e.g. Add -> AddRawKernel<T, Context>.
// node_name:   lowercase guid stem for the device node, e.g. add.
//
// The generated kernel:
//   * allocates the output tensor,
//   * promotes 0-D tensor shapes to rank 1 ({} -> {1}),
//   * detects in-place execution (x and out share the same storage) and
//     folds it into the recipe cache key so in-place and out-of-place
//     recipes are cached separately,
//   * on cache miss builds and compiles a single-node graph via
//     BinaryOperator, caches it, then runs the recipe on the device stream.
//
// `axis` is accepted for PHI kernel-signature compatibility but is not used.
//
// NOTE(review): the cache key gets the "_" suffix only when in_place, while
// the BinaryOperator guid prefix always gets it — confirm this asymmetry is
// intentional.
#define BINARY_RAW_KERNEL(kernel_func, node_name)                              \
  template <typename T, typename Context>                                      \
  void kernel_func##RawKernel(const Context& dev_ctx,                          \
                              const phi::DenseTensor& x,                       \
                              const phi::DenseTensor& y,                       \
                              int axis,                                        \
                              phi::DenseTensor* out) {                         \
    dev_ctx.template Alloc<T>(out);                                            \
    VLOG(6) << "CALL HPU " << #kernel_func << "RawKernel";                     \
    std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());            \
    std::vector<int64_t> y_dim = phi::vectorize<int64_t>(y.dims());            \
    /* 0-D tensors are represented as shape {1} for the graph builder */       \
    if (y_dim.size() == 0) {                                                   \
      y_dim.push_back(1);                                                      \
    }                                                                          \
    if (x_dim.size() == 0) {                                                   \
      x_dim.push_back(1);                                                      \
    }                                                                          \
    bool in_place = (x.data() == out->data());                                 \
    std::vector<int64_t> outputs_dim = phi::vectorize<int64_t>(out->dims());   \
    if (outputs_dim.size() == 0) {                                             \
      outputs_dim.push_back(1);                                                \
    }                                                                          \
    OpCacheOperator op_info;                                                   \
    op_info.prepareOpInfo<T, std::nullptr_t>(                                  \
        in_place ? (std::string(#node_name) + "_") : std::string(#node_name),  \
        {x_dim, y_dim},                                                        \
        nullptr);                                                              \
    auto recipe = op_info.GetRecipe();                                         \
                                                                               \
    if (recipe == nullptr) {                                                   \
      /* cache miss: compile a one-node graph and store it for reuse */        \
      BinaryOperator op(std::string(#node_name) + "_");                        \
      op.AddNode({x_dim, y_dim}, {outputs_dim}, op_info.datatype_, in_place);  \
      op.Compile();                                                            \
      op_info.setOp(op);                                                       \
      recipe = op_info.GetRecipe();                                            \
    }                                                                          \
                                                                               \
    /* bind device addresses to the tensor names declared in AddNode */        \
    std::map<std::string, uint64_t> tensors;                                   \
    tensors["x"] = reinterpret_cast<uint64_t>(x.data<T>());                    \
    tensors["y"] = reinterpret_cast<uint64_t>(y.data<T>());                    \
    tensors["output"] = reinterpret_cast<uint64_t>(out->data<T>());            \
                                                                               \
    RecipeRunner runner(recipe);                                               \
    runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);         \
  }
122
116
#define BINARY_KERNEL (kernel_func ) \
0 commit comments