Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit fb6f150

Browse files
committed
clean up eager compilation code a bit
1 parent 1c583f1 commit fb6f150

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

functorch/_src/eager_compilation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"):
171171
draw_graph(graph, file_name)
172172
return default_partition(graph, joint_inputs)
173173

174+
def normalize_as_list(x):
175+
if isinstance(x, tuple):
176+
return list(x)
177+
elif isinstance(x, list):
178+
return x
179+
return [x]
180+
174181
def create_compiled_function(flat_fn, fw_compiler, bw_compiler, partition_fn):
175182
joint_forward_backward = create_joint_forward_backward(flat_fn)
176183

@@ -196,16 +203,11 @@ def forward(ctx, *flat_args):
196203
# print(fw_module.code, bw_module.code)
197204

198205
compiled_fw = fw_compiler(fw_module, flat_args)
199-
fw_outs = compiled_fw(*flat_args)
200-
201-
if not isinstance(fw_outs, list):
202-
fw_outs = [fw_outs]
206+
fw_outs = normalize_as_list(compiled_fw(*flat_args))
203207

204208
bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
205209
compiled_bw = bw_compiler(bw_module, bw_args)
206-
fw_outs = compiled_fw(*flat_args)
207-
if not isinstance(fw_outs, list):
208-
fw_outs = [fw_outs]
210+
fw_outs = normalize_as_list(compiled_fw(*flat_args))
209211
ctx.save_for_backward(*fw_outs[num_outs:])
210212
if num_outs == 1:
211213
return fw_outs[0]
@@ -215,9 +217,7 @@ def forward(ctx, *flat_args):
215217
def backward(ctx, *flat_args):
216218
# hmm... this doesn't feel right. todo
217219
contiguous_args = [t.contiguous() for t in flat_args]
218-
out = compiled_bw(*ctx.saved_tensors, *contiguous_args)
219-
if not isinstance(out, list):
220-
out = [out]
220+
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
221221
out_iter = iter(out)
222222
grad_out = [next(out_iter) if p else None for p in ctx.needs_input_grad]
223223
return tuple(grad_out)

0 commit comments

Comments
 (0)