File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -151,14 +151,16 @@ def grad_function(fn):
151
151
152
152
153
153
def gradient_execution_order (
154
- roots : Sequence [TensorOrEdge ], with_respect_to : Sequence [TensorOrEdge ]
154
+ roots : Sequence [TensorOrEdge ], with_respect_to : Sequence [Any ]
155
155
) -> List [int ]:
156
156
"""
157
157
Returns the order in which the gradients for `with_respect_to` would become available
158
158
if autograd were run on `roots`. This is the reverse order of each tensors
159
159
first use in the gradient computation.
160
160
"""
161
- with_respect_to = [_gradient_edge (g ) for g in with_respect_to ]
161
+ with_respect_to = [
162
+ (g .node , g .output_nr ) for g in map (_gradient_edge , with_respect_to )
163
+ ]
162
164
min_sequence_nr : Dict [Any , float ] = {e : math .inf for e in with_respect_to }
163
165
164
166
to_scan = [_gradient_edge (r ).node for r in roots ]
You can’t perform that action at this time.
0 commit comments