@@ -171,6 +171,13 @@ def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"):
171
171
draw_graph (graph , file_name )
172
172
return default_partition (graph , joint_inputs )
173
173
174
+ def normalize_as_list (x ):
175
+ if isinstance (x , tuple ):
176
+ return list (x )
177
+ elif isinstance (x , list ):
178
+ return x
179
+ return [x ]
180
+
174
181
def create_compiled_function (flat_fn , fw_compiler , bw_compiler , partition_fn ):
175
182
joint_forward_backward = create_joint_forward_backward (flat_fn )
176
183
@@ -196,16 +203,11 @@ def forward(ctx, *flat_args):
196
203
# print(fw_module.code, bw_module.code)
197
204
198
205
compiled_fw = fw_compiler (fw_module , flat_args )
199
- fw_outs = compiled_fw (* flat_args )
200
-
201
- if not isinstance (fw_outs , list ):
202
- fw_outs = [fw_outs ]
206
+ fw_outs = normalize_as_list (compiled_fw (* flat_args ))
203
207
204
208
bw_args = fw_outs [num_outs :] + fw_outs [0 :num_outs ]
205
209
compiled_bw = bw_compiler (bw_module , bw_args )
206
- fw_outs = compiled_fw (* flat_args )
207
- if not isinstance (fw_outs , list ):
208
- fw_outs = [fw_outs ]
210
+ fw_outs = normalize_as_list (compiled_fw (* flat_args ))
209
211
ctx .save_for_backward (* fw_outs [num_outs :])
210
212
if num_outs == 1 :
211
213
return fw_outs [0 ]
@@ -215,9 +217,7 @@ def forward(ctx, *flat_args):
215
217
def backward (ctx , * flat_args ):
216
218
# hmm... this doesn't feel right. todo
217
219
contiguous_args = [t .contiguous () for t in flat_args ]
218
- out = compiled_bw (* ctx .saved_tensors , * contiguous_args )
219
- if not isinstance (out , list ):
220
- out = [out ]
220
+ out = normalize_as_list (compiled_bw (* ctx .saved_tensors , * contiguous_args ))
221
221
out_iter = iter (out )
222
222
grad_out = [next (out_iter ) if p else None for p in ctx .needs_input_grad ]
223
223
return tuple (grad_out )
0 commit comments