TensorFlow InferShape Analysis
Qiao Longfei edited this page Sep 12, 2017
The most direct reference is the official guide: https://www.tensorflow.org/extend/adding_an_op
op_def_builder defines the methods that are chained onto REGISTER_OP. For example:
```cpp
REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, false /* sparse */);
    })
    .Doc(R"doc(
Update '*var' according to the adagrad scheme.
accum += grad * grad
var -= lr * grad * (1 / sqrt(accum))
var: Should be from a Variable().
accum: Should be from a Variable().
lr: Scaling factor. Must be a scalar.
grad: The gradient.
use_locking: If `True`, updating of the var and accum tensors will be protected
  by a lock; otherwise the behavior is undefined, but may exhibit less
  contention.
)doc");
```

This builds an OpRegistrationData structure:
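```cpp
struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}

  OpDef op_def;                           // op signature built from Input()/Attr()/Doc()
  OpShapeInferenceFn shape_inference_fn;  // the function passed to SetShapeFn()
};
```

Registration mainly fills in these two member variables: op_def and shape_inference_fn.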
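To make the role of these two members concrete, here is a minimal, self-contained C++ sketch that imitates the flow without TensorFlow: a fluent builder fills an op definition and stores a shape function, a static initializer registers the result by name, and a lookup later runs the stored shape function. All names here (MiniOpDefBuilder, MiniOpRegistrationData, MiniInferenceContext, Register, InferOutputShape, MiniApplyAdagrad) are hypothetical and are not TensorFlow APIs. Assumptions: shapes are plain vectors of dimension sizes (empty means scalar), errors are exceptions turned into a bool plus message instead of TensorFlow's Status, and the shape function is a simplified dense stand-in for ApplyAdagradShapeFn (lr must be a scalar, accum and grad must match var, the output takes var's shape), not TensorFlow's actual logic.

```cpp
#include <cassert>
#include <exception>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified stand-ins; NOT TensorFlow's real types.
using Shape = std::vector<long long>;  // empty vector == scalar (rank 0)

struct MiniInferenceContext {
  std::vector<Shape> input_shapes;   // one shape per declared Input(), in order
  std::vector<Shape> output_shapes;  // written by the shape function
};

using MiniShapeFn = std::function<void(MiniInferenceContext*)>;

struct MiniOpDef {
  std::string name;
  std::vector<std::string> inputs;  // raw specs such as "lr: T"
  std::vector<std::string> attrs;   // raw specs such as "T: numbertype"
};

// Mirrors the two members of OpRegistrationData discussed above.
struct MiniOpRegistrationData {
  MiniOpDef op_def;
  MiniShapeFn shape_inference_fn;
};

// Fluent builder: Input()/Attr() fill op_def, SetShapeFn() stores the function.
class MiniOpDefBuilder {
 public:
  explicit MiniOpDefBuilder(std::string name) { data_.op_def.name = std::move(name); }
  MiniOpDefBuilder& Input(std::string s) { data_.op_def.inputs.push_back(std::move(s)); return *this; }
  MiniOpDefBuilder& Attr(std::string s) { data_.op_def.attrs.push_back(std::move(s)); return *this; }
  MiniOpDefBuilder& SetShapeFn(MiniShapeFn fn) { data_.shape_inference_fn = std::move(fn); return *this; }
  const MiniOpRegistrationData& data() const { return data_; }
 private:
  MiniOpRegistrationData data_;
};

// Process-wide registry keyed by op name (function-local static avoids init-order issues).
std::map<std::string, MiniOpRegistrationData>& Registry() {
  static std::map<std::string, MiniOpRegistrationData> registry;
  return registry;
}

bool Register(const MiniOpDefBuilder& builder) {
  const MiniOpRegistrationData& d = builder.data();
  bool inserted = Registry().emplace(d.op_def.name, d).second;
  assert(inserted && "op registered twice");
  return inserted;
}

// A namespace-scope initializer standing in for the static object REGISTER_OP creates.
// Dense (non-resource) variant, so var/accum shapes are ordinary inputs.
static const bool kMiniApplyAdagradRegistered = Register(
    MiniOpDefBuilder("MiniApplyAdagrad")
        .Input("var: T")
        .Input("accum: T")
        .Input("lr: T")
        .Input("grad: T")
        .Attr("T: numbertype")
        .SetShapeFn([](MiniInferenceContext* c) {
          const Shape& var = c->input_shapes.at(0);
          if (c->input_shapes.at(1) != var || c->input_shapes.at(3) != var)
            throw std::invalid_argument("accum and grad must match var");
          if (!c->input_shapes.at(2).empty())
            throw std::invalid_argument("lr must be a scalar");
          c->output_shapes.assign(1, var);  // single output with var's shape
        }));

// Looks up the op, checks the input count against op_def, then runs shape_inference_fn.
// Returns false and fills *error on any failure.
bool InferOutputShape(const std::string& op, const std::vector<Shape>& inputs,
                      Shape* out, std::string* error) {
  try {
    const MiniOpRegistrationData& d = Registry().at(op);
    if (inputs.size() != d.op_def.inputs.size())
      throw std::invalid_argument("wrong number of inputs");
    MiniInferenceContext ctx{inputs, {}};
    d.shape_inference_fn(&ctx);
    *out = ctx.output_shapes.at(0);
    return true;
  } catch (const std::exception& e) {
    *error = e.what();
    return false;
  }
}
```

For example, calling InferOutputShape("MiniApplyAdagrad", {{3, 4}, {3, 4}, {}, {3, 4}}, &out, &err) succeeds with out equal to {3, 4}, while passing {1} for lr fails with "lr must be a scalar". Keeping the shape function as a std::function inside the registration record is what lets a caller run shape inference knowing only the op name and the input shapes, without any kernel.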