1111from typing import Tuple
1212
1313import jax .numpy as jnp
14+ from jax .lax import scan
1415from jax .scipy .special import logsumexp
1516from jax .tree_util import tree_map
16- from jax .lax import scan
1717
1818import brainpy .math as bm
1919from brainpy .types import ArrayType
@@ -106,7 +106,7 @@ def _cel(_pred, _tar):
106106 loss = logsumexp (bm .as_jax (_pred ), axis = - 1 ) - (_pred * _tar ).sum (axis = - 1 )
107107 return _reduce (outputs = loss , reduction = reduction )
108108
109- r = tree_map (_cel , predicts , targets , is_leaf = lambda x : isinstance ( x , bm . Array ) )
109+ r = tree_map (_cel , predicts , targets , is_leaf = _is_leaf )
110110 return _multi_return (r )
111111
112112
@@ -128,7 +128,7 @@ def crs(_prd, _tar):
128128 logits = jnp .take_along_axis (_prd , _tar , - 1 ).squeeze (- 1 )
129129 return logsumexp (bm .as_jax (_prd ), axis = - 1 ) - logits
130130
131- r = tree_map (crs , predicts , targets , is_leaf = lambda x : isinstance ( x , bm . Array ) )
131+ r = tree_map (crs , predicts , targets , is_leaf = _is_leaf )
132132 return _multi_return (r )
133133
134134
@@ -142,9 +142,14 @@ def cross_entropy_sigmoid(predicts, targets):
142142 Returns:
143143 (batch, ...) tensor of the cross-entropies for each entry.
144144 """
145- r = tree_map (lambda pred , tar : jnp .maximum (pred , 0 ) - pred * tar + jnp .log (1 + jnp .exp (- jnp .abs (pred ))),
146- predicts ,
147- targets )
145+ r = tree_map (
146+ lambda pred , tar : bm .as_jax (
147+ bm .maximum (pred , 0 ) - pred * tar + bm .log (1 + bm .exp (- bm .abs (pred )))
148+ ),
149+ predicts ,
150+ targets ,
151+ is_leaf = _is_leaf
152+ )
148153 return _multi_return (r )
149154
150155
@@ -201,7 +206,7 @@ def loss(pred, tar):
201206 norm = jnp .linalg .norm (bm .as_jax (diff ), ord = 1 , axis = 1 , keepdims = False )
202207 return _reduce (outputs = norm , reduction = reduction )
203208
204- r = tree_map (loss , logits , targets , is_leaf = lambda x : isinstance ( x , bm . Array ) )
209+ r = tree_map (loss , logits , targets , is_leaf = _is_leaf )
205210 return _multi_return (r )
206211
207212
@@ -228,7 +233,9 @@ def l2_loss(predicts, targets):
228233 ----------
229234 .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning.
230235 """
231- r = tree_map (lambda pred , tar : 0.5 * (pred - tar ) ** 2 , predicts , targets )
236+ r = tree_map (lambda pred , tar : 0.5 * (pred - tar ) ** 2 ,
237+ predicts ,
238+ targets )
232239 return _multi_return (r )
233240
234241
@@ -243,7 +250,10 @@ def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'):
243250 Returns:
244251 tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error.
245252 """
246- r = tree_map (lambda a , b : _reduce (jnp .abs (a - b ), reduction = reduction , axis = axis ), x , y )
253+ r = tree_map (lambda a , b : _reduce (bm .abs (a - b ), reduction = reduction , axis = axis ),
254+ x ,
255+ y ,
256+ is_leaf = _is_leaf )
247257 return _multi_return (r )
248258
249259
@@ -260,7 +270,8 @@ def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'):
260270 """
261271 r = tree_map (lambda a , b : _reduce ((a - b ) ** 2 , reduction , axis = axis ),
262272 predicts ,
263- targets )
273+ targets ,
274+ is_leaf = _is_leaf )
264275 return _multi_return (r )
265276
266277
@@ -276,7 +287,9 @@ def mean_squared_log_error(predicts, targets, axis=None, reduction: str = 'mean'
276287 tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error.
277288 """
278289 r = tree_map (lambda a , b : _reduce ((jnp .log1p (a ) - jnp .log1p (b )) ** 2 , reduction , axis = axis ),
279- predicts , targets , is_leaf = _is_leaf )
290+ predicts ,
291+ targets ,
292+ is_leaf = _is_leaf )
280293 return _multi_return (r )
281294
282295
@@ -309,12 +322,13 @@ def huber_loss(predicts, targets, delta: float = 1.0):
309322 def _loss (y_predict , y_target ):
310323 # 0.5 * err^2 if |err| <= d
311324 # 0.5 * d^2 + d * (|err| - d) if |err| > d
312- diff = jnp .abs (y_predict - y_target )
313- return jnp .where (diff > delta ,
314- delta * (diff - .5 * delta ),
315- 0.5 * diff ** 2 )
325+ diff = bm .abs (y_predict - y_target )
326+ r = bm .where (diff > delta ,
327+ delta * (diff - .5 * delta ),
328+ 0.5 * diff ** 2 )
329+ return bm .as_jax (r )
316330
317- r = tree_map (_loss , targets , predicts )
331+ r = tree_map (_loss , targets , predicts , is_leaf = _is_leaf )
318332 return _multi_return (r )
319333
320334
@@ -382,7 +396,7 @@ def loss(pred, tar):
382396 log_not_p = bm .log_sigmoid (- pred )
383397 return - tar * log_p - (1. - tar ) * log_not_p
384398
385- r = tree_map (loss , logits , labels , is_leaf = lambda x : isinstance ( x , bm . Array ) )
399+ r = tree_map (loss , logits , labels , is_leaf = _is_leaf )
386400 return _multi_return (r )
387401
388402
@@ -433,7 +447,7 @@ def loss(pred, tar):
433447 errors = bm .as_jax (pred - tar )
434448 return jnp .logaddexp (errors , - errors ) - jnp .log (2.0 ).astype (errors .dtype )
435449
436- r = tree_map (loss , predicts , targets , is_leaf = lambda x : isinstance ( x , bm . Array ) )
450+ r = tree_map (loss , predicts , targets , is_leaf = _is_leaf )
437451 return _multi_return (r )
438452
439453