12 | 12 |
13 | 13 | # pylint: disable=unused-argument,missing-docstring
14 | 14 |
| 15 | + |
| 16 | +def softmax_cross_entropy_with_logits_computation(ctx, label, logit, tf_ori_node): |
| 17 | + label_dtype = ctx.get_dtype(label.output[0]) |
| 18 | + logit_dtype = ctx.get_dtype(logit.output[0]) |
| 19 | + utils.make_sure(label_dtype == logit_dtype, "the following logic only works when label and logit have the same dtype") |
| 20 | + |
| 21 | + log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output) |
| 22 | + # implements tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=-1)) |
| 23 | + mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]]) |
| 24 | + reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]}) |
| 25 | + const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"), |
| 26 | + np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])) |
| 27 | + mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]]) |
| 28 | + shapes = tf_ori_node.output_shapes |
| 29 | + dtypes = tf_ori_node.output_dtypes |
| 30 | + ctx.remove_node(tf_ori_node.name) |
| 31 | + ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]}, |
| 32 | + outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]]) |
| 33 | + |
| 34 | + |
| 35 | +def softmax_cross_entropy_with_logits_op(ctx, node, name, args): |
| 36 | + logits = node.inputs[0] |
| 37 | + logit_dtype = ctx.get_dtype(logits.output[0]) |
| 38 | + labels = node.inputs[1] |
| 39 | + label_dtype = ctx.get_dtype(labels.output[0]) |
| 40 | + if label_dtype != logit_dtype: |
| 41 | + labels = ctx.make_node("Cast", labels.output, attr={"to": logit_dtype}, dtypes=[logit_dtype]) |
| 42 | + |
| 43 | + softmax_cross_entropy_with_logits_computation(ctx, labels, logits, node) |
| 44 | + |
| 45 | + |
15 | 46 | def sparse_softmax_cross_entropy_with_logits_op(ctx, node, name, args):
16 | 47 | # make subgraph to implement one_hot, idea comes from onehot_op
17 | 48 | indices_name = node.input[1]
@@ -92,3 +123,32 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, arg
92 | 123 | ctx.make_node(op_type="Squeeze",
93 | 124 | inputs=[mul2.output[0]], outputs=[node.output[0]],
94 | 125 | attr={"axes": [1]}, shapes=[shapes[0]], dtypes=[dtypes[0]])
| 126 | + |
| 127 | + |
| 128 | +def sparse_softmax_cross_entropy_with_logits_op9(ctx, node, name, args): |
| 129 | + # float32/64 output = SparseSoftmaxCrossEntropyWithLogits(float32/64 features, int32/64 labels) |
| 130 | + # the math of this op is: a = onehot(labels), b = logsoftmax(features), res = -reduce_sum(mul(a, b), axis=-1) |
| 131 | + logit_node = node.inputs[0] |
| 132 | + logit_shape = ctx.get_shape(node.input[0]) |
| 133 | + logit_dtype = ctx.get_dtype(node.input[0]) |
| 134 | + |
| 135 | + label_name = node.input[1] |
| 136 | + label_dtype = ctx.get_dtype(label_name) |
| 137 | + |
| 138 | + num_class = logit_shape[-1] |
| 139 | + utils.make_sure(num_class != -1, "the number of classes must be known; otherwise a subgraph to compute it is needed") |
| 140 | + # int64 is used because onnxruntime's OneHot only supports this dtype |
| 141 | + depth_node = ctx.make_const(utils.make_name("onehot_depth"), np.array([num_class]).astype(np.int64)) |
| 142 | + values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64)) |
| 143 | + if label_dtype != TensorProto.INT64: |
| 144 | + onehot_indices = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0] |
| 145 | + else: |
| 146 | + onehot_indices = label_name |
| 147 | + label_node = ctx.make_node(op_type="OneHot", inputs=[onehot_indices, depth_node.output[0], values_node.output[0]]) |
| 148 | + # the logic above makes the output dtype of label_node int64 in all cases |
| 149 | + # make sure label has same dtype as logit |
| 150 | + if logit_dtype != TensorProto.INT64: |
| 151 | + label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype]) |
| 152 | + |
| 153 | + softmax_cross_entropy_with_logits_computation(ctx, label_node, logit_node, node) |
| 154 | + |
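
For reference, the math the added subgraph encodes can be checked outside the converter. The numpy sketch below is illustrative only (not tf2onnx code, and it assumes 2-D [batch, num_classes] logits); it reproduces the onehot -> logsoftmax -> mul -> reducesum -> negate chain built above:

import numpy as np

def log_softmax(x, axis=-1):
    # subtract the row max for numerical stability, then apply log-softmax
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]], dtype=np.float32)
labels = np.array([0, 1])                                      # sparse class ids
one_hot = np.eye(logits.shape[-1], dtype=np.float32)[labels]   # OneHot step
loss = -np.sum(one_hot * log_softmax(logits), axis=-1)         # Mul -> ReduceSum -> Mul(-1), Squeeze not needed without keepdims
print(loss)  # should match tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)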