
TensorFlow InferShape Analysis

Qiao Longfei edited this page Sep 12, 2017 · 5 revisions

The most direct documentation is at: https://www.tensorflow.org/extend/adding_an_op

op_def_builder defines the builder methods (such as Input, Attr, SetShapeFn, and Doc) that a REGISTER_OP call chains together. For example:

REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, false /* sparse */);
    })
    .Doc(R"doc(
Update '*var' according to the adagrad scheme.

accum += grad * grad
var -= lr * grad * (1 / sqrt(accum))

var: Should be from a Variable().
accum: Should be from a Variable().
lr: Scaling factor. Must be a scalar.
grad: The gradient.
use_locking: If `True`, updating of the var and accum tensors will be protected
  by a lock; otherwise the behavior is undefined, but may exhibit less
  contention.
)doc");

A REGISTER_OP call like this one constructs an OpRegistrationData structure:

struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}

  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
};

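The registration mainly fills in its two member variables: op_def, which holds the op's declared inputs, attributes, and documentation, and shape_inference_fn, which holds the function passed to SetShapeFn.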
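To make the relationship between the chained builder calls and the two members concrete, here is a minimal, self-contained sketch. It is not TensorFlow's implementation: OpDef, InferenceContext, and SketchOpDefBuilder are simplified stand-ins invented for illustration, the shape function returns bool rather than TensorFlow's Status, and the op name ToyApplyAdagrad is hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for TensorFlow's OpDef proto (assumption: only a name,
// input specs, and attr specs are kept here).
struct OpDef {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> attrs;
};

// Simplified stand-in for shape_inference::InferenceContext (assumption:
// shapes are plain vectors of ints).
struct InferenceContext {
  std::vector<std::vector<int>> input_shapes;
  std::vector<std::vector<int>> output_shapes;
};

// The real OpShapeInferenceFn returns a Status; bool is used to stay minimal.
using OpShapeInferenceFn = std::function<bool(InferenceContext*)>;

// Same two members as the struct quoted above.
struct OpRegistrationData {
  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
};

// Hypothetical builder: every chained call records its argument into the
// OpRegistrationData that Finalize() hands back.
class SketchOpDefBuilder {
 public:
  explicit SketchOpDefBuilder(std::string name) {
    data_.op_def.name = std::move(name);
  }
  SketchOpDefBuilder& Input(std::string spec) {
    data_.op_def.inputs.push_back(std::move(spec));
    return *this;
  }
  SketchOpDefBuilder& Attr(std::string spec) {
    data_.op_def.attrs.push_back(std::move(spec));
    return *this;
  }
  SketchOpDefBuilder& SetShapeFn(OpShapeInferenceFn fn) {
    data_.shape_inference_fn = std::move(fn);
    return *this;
  }
  OpRegistrationData Finalize() const { return data_; }

 private:
  OpRegistrationData data_;
};

// Mirrors the call chain of the REGISTER_OP example for a toy op whose single
// output has the same shape as input 0.
inline OpRegistrationData BuildToyRegistration() {
  return SketchOpDefBuilder("ToyApplyAdagrad")
      .Input("var: resource")
      .Input("accum: resource")
      .Input("lr: T")
      .Input("grad: T")
      .Attr("T: numbertype")
      .Attr("use_locking: bool = false")
      .SetShapeFn([](InferenceContext* c) {
        if (c->input_shapes.empty()) return false;
        c->output_shapes.clear();
        c->output_shapes.push_back(c->input_shapes[0]);
        return true;
      })
      .Finalize();
}

// Example use: build the registration and run its shape function once.
inline void DemoToyRegistration() {
  OpRegistrationData reg = BuildToyRegistration();
  InferenceContext ctx;
  ctx.input_shapes.push_back({2, 3});
  assert(reg.shape_inference_fn(&ctx));
  assert(ctx.output_shapes.size() == 1);
}
```

In this sketch, Input and Attr only append to op_def, while SetShapeFn stores the callable in shape_inference_fn. That split is the point of the structure: the op's declared interface and its shape logic travel together in one record.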
