
Commit 3db8984

albanD authored and facebook-github-bot committed

Fix monarch gradient generation order (#853)

Summary:
Pull Request resolved: #853
Don't assume that next_functions contains GradientEdge.

Reviewed By: zdevito
Differential Revision: D80180819
fbshipit-source-id: a4e1233049eb29f46b6baa83342c363cb7098e79
1 parent 8c62bec commit 3db8984

File tree

1 file changed (+4, −2 lines)


python/monarch/gradient_generator.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -151,14 +151,16 @@ def grad_function(fn):
 
 
 def gradient_execution_order(
-    roots: Sequence[TensorOrEdge], with_respect_to: Sequence[TensorOrEdge]
+    roots: Sequence[TensorOrEdge], with_respect_to: Sequence[Any]
 ) -> List[int]:
     """
     Returns the order in which the gradients for `with_respect_to` would become available
     if autograd were run on `roots`. This is the reverse order of each tensors
     first use in the gradient computation.
     """
-    with_respect_to = [_gradient_edge(g) for g in with_respect_to]
+    with_respect_to = [
+        (g.node, g.output_nr) for g in map(_gradient_edge, with_respect_to)
+    ]
     min_sequence_nr: Dict[Any, float] = {e: math.inf for e in with_respect_to}
 
     to_scan = [_gradient_edge(r).node for r in roots]
```
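The change normalizes each entry of `with_respect_to` to a plain `(node, output_nr)` tuple before using it as a dictionary key, rather than keying the dict by whatever edge object `_gradient_edge` returns. A minimal sketch of why this matters, using a hypothetical `GradientEdge` stand-in class (not Monarch's or PyTorch's actual type): edges encountered later while scanning a node's `next_functions` may arrive in a different representation than the original wrapper objects, and tuple keys compare by value, so lookups succeed regardless of which representation produced the edge.

```python
import math
from typing import Any, Dict


class GradientEdge:
    """Hypothetical stand-in for an autograd edge wrapper holding (node, output_nr)."""

    def __init__(self, node: Any, output_nr: int) -> None:
        self.node = node
        self.output_nr = output_nr


node = object()  # stand-in for an autograd Node

# Keying the dict by wrapper instances: a second wrapper for the *same* edge
# hashes by identity by default, so the lookup misses.
by_wrapper: Dict[Any, float] = {GradientEdge(node, 0): math.inf}
assert GradientEdge(node, 0) not in by_wrapper

# Normalizing to (node, output_nr) tuples first: tuples compare by value, so
# the same edge is found no matter how it was produced.
edges = [GradientEdge(node, 0)]
min_sequence_nr: Dict[Any, float] = {
    (e.node, e.output_nr): math.inf for e in edges
}
assert (node, 0) in min_sequence_nr
```

This is only an illustration of the keying pattern; the real function additionally walks the graph from `roots` and records the minimum sequence number at which each requested edge is first used.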
