@@ -128,11 +128,10 @@ def forward(ctx, *flat_args):
128
128
129
129
bw_args = fw_outs [num_outs :] + fw_outs [0 :num_outs ]
130
130
compiled_bw = bw_compiler (bw_module , bw_args )
131
-
132
131
fw_outs = compiled_fw (* flat_args )
133
132
if not isinstance (fw_outs , list ):
134
133
fw_outs = [fw_outs ]
135
- ctx .activations = fw_outs [num_outs :]
134
+ ctx .save_for_backward ( * fw_outs [num_outs :])
136
135
if num_outs == 1 :
137
136
return fw_outs [0 ]
138
137
return tuple (fw_outs [0 :num_outs ])
@@ -141,24 +140,28 @@ def forward(ctx, *flat_args):
141
140
def backward (ctx , * flat_args ):
142
141
# hmm... this doesn't feel right. todo
143
142
contiguous_args = [t .contiguous () for t in flat_args ]
144
- out = compiled_bw (* ctx .activations , * contiguous_args )
143
+ out = compiled_bw (* ctx .saved_tensors , * contiguous_args )
145
144
if not isinstance (out , list ):
146
145
out = [out ]
147
146
out_iter = iter (out )
148
147
grad_out = [next (out_iter ) if p else None for p in ctx .needs_input_grad ]
149
148
return tuple (grad_out )
150
-
149
+
151
150
return CompiledFunction
152
151
153
152
153
+ # using this reduces the overhead by about 50%
154
+ # import tree
154
155
def compiled_function (fn , fw_compiler , bw_compiler , partition_fn = default_partition ):
155
156
saved_fn = None
156
157
157
158
def returned_function (* args , ** kwargs ):
158
159
nonlocal saved_fn
159
- flattened_args , args_spec = pytree .tree_flatten ((args , kwargs ))
160
+ # flattened_args = tree.flatten((args, kwargs))
161
+ flattened_args , _ = pytree .tree_flatten ((args , kwargs ))
160
162
161
163
if saved_fn is None :
164
+ flattened_args , args_spec = pytree .tree_flatten ((args , kwargs ))
162
165
def flat_fn (* args ):
163
166
args , kwargs = pytree .tree_unflatten (args , args_spec )
164
167
return fn (* args , ** kwargs )
0 commit comments