@@ -79,31 +79,28 @@ class OperatorBase {
79
79
80
80
virtual ~OperatorBase () {}
81
81
82
- template <typename T>
83
- inline const T& Attr (const std::string& name) const {
84
- PADDLE_ENFORCE (attrs_.count (name) != 0 , " %s should be in AttributeMap" ,
85
- name);
86
- return boost::get<T>(attrs_.at (name));
87
- }
88
-
89
- // / if scope is not null, also show dimensions of arguments
90
- virtual std::string DebugStringEx (const Scope* scope) const ;
91
-
92
- std::string DebugString () const { return DebugStringEx (nullptr ); }
93
-
94
- // / Net will call this interface function to Run an op.
82
+ // / Executor will call this interface function to Run an op.
95
83
// The implementation should be written at RunImpl
96
84
void Run (const Scope& scope, const platform::Place& place);
97
85
98
86
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
99
87
virtual void Stop () {}
100
88
101
- virtual bool IsNetOp () const { return false ; }
89
+ // / if scope is not null, also show dimensions of arguments
90
+ virtual std::string DebugStringEx (const Scope* scope) const ;
91
+ std::string DebugString () const { return DebugStringEx (nullptr ); }
102
92
103
93
virtual bool SupportGPU () const { return false ; }
104
94
105
- // / rename inputs outputs name
106
- void Rename (const std::string& old_name, const std::string& new_name);
95
+ const std::string& Type () const { return type_; }
96
+
97
+ template <typename T>
98
+ inline const T& Attr (const std::string& name) const {
99
+ PADDLE_ENFORCE (attrs_.count (name) != 0 , " %s should be in AttributeMap" ,
100
+ name);
101
+ return boost::get<T>(attrs_.at (name));
102
+ }
103
+ const AttributeMap& Attrs () const { return attrs_; }
107
104
108
105
const VariableNameMap& Inputs () const { return inputs_; }
109
106
const VariableNameMap& Outputs () const { return outputs_; }
@@ -112,21 +109,17 @@ class OperatorBase {
112
109
std::string Input (const std::string& name) const ;
113
110
// ! Get a input which has multiple variables.
114
111
const std::vector<std::string>& Inputs (const std::string& name) const ;
115
-
112
+ // ! Get all inputs variable names
116
113
std::vector<std::string> InputVars () const ;
117
114
118
115
// ! Get a output with argument's name described in `op_proto`
119
116
std::string Output (const std::string& name) const ;
120
117
// ! Get an output which has multiple variables.
121
118
// ! TODO add a vector_view to prevent memory copy.
122
119
const std::vector<std::string>& Outputs (const std::string& name) const ;
123
-
120
+ // ! Get all outputs variable names
124
121
virtual std::vector<std::string> OutputVars (bool has_intermediate) const ;
125
122
126
- const std::string& Type () const { return type_; }
127
- void SetType (const std::string& type) { type_ = type; }
128
- const AttributeMap& Attrs () const { return attrs_; }
129
-
130
123
// Return a new operator instance, which is as same as this.
131
124
// Use unique_ptr to prevent caller forget to delete this pointer.
132
125
virtual std::unique_ptr<OperatorBase> Clone () const = 0;
@@ -278,20 +271,6 @@ class ExecutionContext {
278
271
return res;
279
272
}
280
273
281
- void ShareLoD (const std::string& in, const std::string& out, size_t i = 0 ,
282
- size_t j = 0 ) const {
283
- PADDLE_ENFORCE_LT (i, InputSize (in));
284
- PADDLE_ENFORCE_LT (j, OutputSize (out));
285
- auto * in_var = MultiInputVar (in)[i];
286
- auto * out_var = MultiOutputVar (out)[j];
287
- if (!in_var->IsType <LoDTensor>()) return ;
288
- PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
289
- " The %d-th output of Output(%s) must be LoDTensor." , j, out);
290
- auto in_tensor = in_var->Get <LoDTensor>();
291
- auto * out_tensor = out_var->GetMutable <LoDTensor>();
292
- out_tensor->set_lod (in_tensor.lod ());
293
- }
294
-
295
274
platform::Place GetPlace () const { return device_context_.GetPlace (); }
296
275
297
276
template <typename DeviceContextType>
0 commit comments