Skip to content
This repository was archived by the owner on Aug 11, 2025. It is now read-only.

Commit d3a3c3b

Browse files
cesarsouza authored and migueldeicaza committed
Adding support for tf.sigmoid_cross_entropy_with_logits and tf.where (#128)
* GH-127: Add support for tf.sigmoid_cross_entropy_with_logits and tf.where * Adding a note stating that part of the original TF implementation has been left behind since it wasn't needed for TFSharp.
1 parent 60299b7 commit d3a3c3b

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

TensorFlowSharp/OperationsExtras.cs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,77 @@ public TFOutput ClipByAverageNorm (TFOutput x, TFOutput clip_norm, string operNa
488488
}
489489
}
490490

491+
/// <summary>
/// Computes sigmoid cross entropy given `logits`.
/// </summary>
///
/// <remarks>
/// Measures the probability error in discrete classification tasks in which each
/// class is independent and not mutually exclusive. For instance, one could
/// perform multilabel classification where a picture can contain both an elephant
/// and a dog at the same time.
/// </remarks>
///
public TFOutput SigmoidCrossEntropyWithLogits (TFOutput labels, TFOutput logits, string operName = null)
{
	// Ported from: https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/ops/nn_impl.py#L100

	var scopeName = this.MakeName ("logistic_loss", operName);
	using (var newScope = this.WithScope (scopeName)) {
		// Note: unlike the original Python implementation, no convert_to_tensor
		// calls or shape-merging validation are needed here: the TensorFlowSharp
		// API guarantees by design that both labels and logits arrive as TFOutput.

		// The logistic loss formula is
		//   x - x * z + log(1 + exp(-x))
		// and for x < 0 the numerically stable equivalent is
		//   -x * z + log(1 + exp(x)).
		// The two cases collapse into the single expression
		//   max(x, 0) - x * z + log(1 + exp(-abs(x)))
		// where max and abs are built from Where so that gradients at
		// zero remain well defined.
		TFOutput zeros = this.ZerosLike (logits);
		TFOutput isNonNegative = this.GreaterEqual (logits, zeros);

		// max(x, 0) and -abs(x), expressed as element-wise selections.
		TFOutput reluLogits = this.Where (isNonNegative, logits, zeros);
		TFOutput negAbsLogits = this.Where (isNonNegative, this.Neg (logits), logits);

		TFOutput linearTerm = this.Sub (reluLogits, this.Mul (logits, labels));
		TFOutput softplusTerm = this.Log1p (this.Exp (negAbsLogits));
		return this.Add (linearTerm, softplusTerm, operName: operName);
	}
}
540+
541+
/// <summary>
/// Return elements from x or y depending on condition.
/// </summary>
///
/// <param name="condition">Output of type `bool`.</param>
/// <param name="x">Output providing values where condition is true, or null.</param>
/// <param name="y">Output providing values where condition is false, or null.</param>
/// <param name="name">Optional op name.</param>
///
/// <returns>
/// The output with values selected according to condition. When both x and y
/// are null, the result of the single-input Where op applied to condition.
/// </returns>
///
/// <exception cref="ArgumentException">
/// Thrown when exactly one of x and y is null: callers must supply both
/// branches or neither.
/// </exception>
///
public TFOutput Where (TFOutput condition, TFOutput? x, TFOutput? y, string name = null)
{
	// https://github.com/tensorflow/tensorflow/blob/d4ce3b4681b3a550c095b2cd18a79494d1cc4039/tensorflow/python/ops/array_ops.py#L2342
	if (x == null && y == null)
		return this.Where (input: condition, operName: name);
	else if (x != null && y != null)
		return this.Select (condition: condition, t: x.Value, e: y.Value, operName: name);
	// Mixed usage is invalid; the message uses C# "null" terminology rather
	// than the Python "None" wording of the original implementation.
	throw new ArgumentException ("x and y must both be non-null or both be null.");
}
561+
491562
/// <summary>
492563
/// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
493564
/// </summary>

tests/TensorFlowSharp.Tests.CSharp/MathTests.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,35 @@ public void Should_ReduceMean (double [,] input, int? axis, object expected)
7171
}
7272
}
7373
}
74+
75+
// Test rows: { labels, logits, expected losses }; a null expected value
// marks a case whose mismatched shapes must make the session throw.
private static IEnumerable<object []> sigmoidCrossEntropyData ()
{
	var rows = new [] {
		new object [] { new [] { 1.0, 0.0, 1.0, 1.0 }, new [] { 1.0, 0.0, 1.0, 1.0 }, new [] { 0.31326168751822281, 0.69314718055994529, 0.31326168751822281, 0.31326168751822281 } },
		new object [] { new [] { 1.0, 0.0, 1.0, 1.0 }, new [] { -0.2, 4.2, 0.0, 0.0 }, new [] { 0.79813886938159184, 4.2148842546719187, 0.69314718055994529, 0.69314718055994529 } },
		new object [] { new [] { 1.0, 0.0 }, new [] { -2.1, -2, -4, 3.0 }, null },
	};

	foreach (var row in rows)
		yield return row;
}
81+
82+
[Theory]
[MemberData (nameof (sigmoidCrossEntropyData))]
public void Should_SigmoidCrossEntropyWithLogits (double [] labels, double [] logits, double [] expected)
{
	using (var graph = new TFGraph ())
	using (var session = new TFSession (graph)) {
		var labelsPlaceholder = graph.Placeholder (TFDataType.Double, new TFShape (2, 2));
		var logitsPlaceholder = graph.Placeholder (TFDataType.Double, new TFShape (2, 2));

		TFOutput loss = graph.SigmoidCrossEntropyWithLogits (labelsPlaceholder, logitsPlaceholder);

		if (expected == null) {
			// Mismatched label/logit shapes must surface as a TFException when run.
			Assert.Throws<TFException> (() => session.Run (new [] { labelsPlaceholder, logitsPlaceholder }, new TFTensor [] { labels, logits }, new [] { loss }));
		} else {
			TFTensor [] output = session.Run (new [] { labelsPlaceholder, logitsPlaceholder }, new TFTensor [] { labels, logits }, new [] { loss });

			var actual = (double [])output [0].GetValue ();
			TestUtils.MatrixEqual (expected, actual, precision: 8);
		}
	}
}
103+
74104
}
75105
}

0 commit comments

Comments
 (0)