Skip to content

Commit 0b3d8fc

Browse files
authored
Feature/op standard (#12860)
* new doc * standard
1 parent 9ee698e commit 0b3d8fc

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

doc/fluid/dev/new_op_cn.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,29 @@ $$Out = scale*X$$
119119
120120
这个例子有`AddAttr<AttrType>("scale", "...").SetDefault(1.0);` : 增加`scale`系数,作为参数属性,并且设置默认值为1.0。
121121
122+
### 定义GradProtoMaker类
123+
每个Op的必须有一个对应的GraProtoMaker,若未定制对应前向Op的GradProtoMaker,fluid提供了DefaultGradProtoMaker,默认注册会使用全部输入输出,包括Input, Output, Output@Grad等,使用不需要的变量的会造成显存浪费。
124+
下面示例定义了ScaleOp的GradProtoMaker。
125+
126+
```cpp
127+
class ScaleGradMaker : public framework::SingleGradOpDescMaker {
128+
public:
129+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
130+
131+
std::unique_ptr<framework::OpDesc> Apply() const override {
132+
auto *grad_op = new framework::OpDesc();
133+
grad_op->SetType("scale");
134+
grad_op->SetInput("X", OutputGrad("Out"));
135+
grad_op->SetOutput("Out", InputGrad("X"));
136+
grad_op->SetAttr("scale", GetAttr("scale"));
137+
return std::unique_ptr<framework::OpDesc>(grad_op);
138+
}
139+
};
140+
```
122141
123142
### 定义Operator类
124143
125-
下面的点实现了MulOp的定义
144+
下面实现了MulOp的定义
126145
127146
```cpp
128147
class MulOp : public framework::OperatorWithKernel {
@@ -383,6 +402,19 @@ PADDLE_ENFORCE(forward_pd != nullptr,
383402
"Fail to find eltwise_fwd_pd in device context"); //eltwise_fwd_pd用户可能看不懂
384403
```
385404

405+
3. OP内部调用非法接口:Op内部如果出现Output = ShareDataWith(Input)
406+
问题示例:
407+
```cpp
408+
auto *out = ctx.Output<framework::LoDTensor>("Out");
409+
auto *in = ctx.Input<framework::LoDTensor>("X");
410+
out->ShareDataWith(*in);
411+
```
412+
Op内部如果出现Output = ShareDataWith(Input),相当于operator图的中有一条隐藏边,连接了Input和Output,这条边无法在图分析中表达,引发基于图优化的错误。
413+
414+
4. OP实现的性能实践
415+
调用了eigen的broadcast, chop等操作,性能会比手写cuda kernel差几倍以上。此时cpu的实现可以复用eigen,gpu实现可以实现cuda kernel.
416+
417+
386418
#### OP InferShape检查提示信息特别说明
387419
388420
- 检查输入输出变量,请统一遵循以下格式

0 commit comments

Comments
 (0)