File tree Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Original file line number Diff line number Diff line change @@ -319,11 +319,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
319319 if has_non_batch_dim :
320320 # There is one non-batched dimension, it is the first one
321321 non_batch_dim_len = output .shape [0 ]
322- jac_output_shape = [output .shape [0 ]] + list (output .shape )
323-
324- jac_output = torch .zeros (jac_output_shape , device = output .device , dtype = output .dtype )
325- for i in range (non_batch_dim_len ):
326- jac_output [i , i , ...] = 1.0
322+ identity_matrix = torch .eye (non_batch_dim_len , device = output .device , dtype = output .dtype )
323+ ones = torch .ones_like (output [0 ])
324+ jac_output = torch .einsum ("ij, ... -> ij..." , identity_matrix , ones )
327325
328326 _ = vmap (differentiation )(jac_output )
329327 else :
You can’t perform that action at this time.
0 commit comments