File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ def _create_differentiable(inps, level=None):
30
30
def create_differentiable (x ):
31
31
if isinstance (x , torch .Tensor ):
32
32
return x .requires_grad_ ()
33
- raise ValueError (f'Thing passed to transform API must be Tensor,'
33
+ raise ValueError (f'Thing passed to transform API must be Tensor, '
34
34
f'got { type (x )} ' )
35
35
return tree_map (create_differentiable , inps )
36
36
@@ -296,12 +296,12 @@ def wrapper(*args, **kwargs):
296
296
output , aux = output
297
297
298
298
if not isinstance (output , torch .Tensor ):
299
- raise RuntimeError ('grad_and_value(f)(*args): Expected f(*args)'
299
+ raise RuntimeError ('grad_and_value(f)(*args): Expected f(*args) '
300
300
f'to return a Tensor, got { type (output )} ' )
301
301
if output .dim () != 0 :
302
- raise RuntimeError ('grad_and_value(f)(*args): Expected f(*args)'
302
+ raise RuntimeError ('grad_and_value(f)(*args): Expected f(*args) '
303
303
'to return a scalar Tensor, got tensor with '
304
- f'{ output .dim ()} dims. Maybe you wanted to'
304
+ f'{ output .dim ()} dims. Maybe you wanted to '
305
305
'use the vjp or jacrev APIs instead?' )
306
306
307
307
flat_diff_args , spec = tree_flatten (diff_args )
You can’t perform that action at this time.
0 commit comments