I'm trying to modify MyFunc (see below) so that I don't get a zero gradient at all input values. The FAQ states:

So I thought that if I replaced my original code:

with:

then I would finally not get zero gradients, but this doesn't work either. Could someone help, please?
I don't quite follow your full code example, but I think the way to do a "soft" version of your snippet:

```python
if fPoints >= dPoints:
    balance += fOdds * 1000
```

would be something like this:

```python
import jax.nn
balance += fOdds * 1000 * jax.nn.sigmoid(fPoints - dPoints)
```

This uses the sigmoid function in place of your (implied) step function, so that if `fPoints` is much smaller than `dPoints`, the resulting operation is `balance += 0`, and if `fPoints` is much larger than `dPoints`, the resulting operation is `balance += fOdds * 1000`, and the output varies smoothly between these two extremes. The recommendation for `softmax` in the docs you linked to is as a replacement for `argmax`, which your code doesn't use, so I don't think it applies here.

Hope that helps!
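To see the difference concretely, here is a minimal sketch comparing the gradient of the hard threshold with the sigmoid-smoothed version. The `fOdds`/`fPoints`/`dPoints` names come from the thread; the concrete values are made up for illustration, and `jnp.where` stands in for the original `if`/`else`:

```python
import jax
import jax.numpy as jnp

fOdds = 2.5

def hard_balance(fPoints, dPoints):
    # jnp.where implements the original if/else branch; both branches
    # are constant in fPoints, so the gradient is zero almost everywhere.
    return jnp.where(fPoints >= dPoints, fOdds * 1000.0, 0.0)

def soft_balance(fPoints, dPoints):
    # Smooth surrogate: sigmoid replaces the step function.
    return fOdds * 1000.0 * jax.nn.sigmoid(fPoints - dPoints)

g_hard = jax.grad(hard_balance)(3.0, 5.0)  # 0.0: the step is flat on both sides
g_soft = jax.grad(soft_balance)(3.0, 5.0)  # nonzero: sigmoid has slope everywhere
print(g_hard, g_soft)
```

The soft gradient here is `fOdds * 1000 * s * (1 - s)` with `s = sigmoid(fPoints - dPoints)`, which is small far from the threshold but never exactly zero, so gradient-based optimization can make progress.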