File tree Expand file tree Collapse file tree 2 files changed +11
-3
lines changed
Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -311,9 +311,13 @@ def post_hook(self, out):
311311 post_out = tuple (post_out )
312312 return post_out
313313
314- def forward (self , * args ):
314+ def forward (self , * args , ** kwargs ):
315+ signature = inspect .signature (self ._module .forward )
316+ bound_args = signature .bind (* args , ** kwargs )
317+ args = bound_args .args
315318 arg_list = self .pre_hook (* args )
316319
320+
317321 if self .all_input_no_grad and not self .all_param_no_grad :
318322 placeholder = torch .tensor ([], requires_grad = torch .is_grad_enabled ())
319323 return hook_func .OneStepNoGradFunc .apply (self , placeholder , * arg_list )
Original file line number Diff line number Diff line change @@ -15,11 +15,15 @@ def make_distributed(model: torch.nn.Module):
1515 for kw in list (model ._buffers .keys ()):
1616 if model ._buffers [kw ] is not None :
1717 model ._buffers [kw ] = model ._buffers [kw ].cuda ()
18-
18+ is_module_list = isinstance (model , torch .nn .ModuleList )
19+ pre_module = None
1920 for kw in list (model ._modules .keys ()):
20- if isinstance ( model , torch . nn . ModuleList ) :
21+ if is_module_list :
2122 if not isinstance (model ._modules [kw ], Block ):
2223 model ._modules [kw ] = Block (model_wrapper_dispatch (model ._modules [kw ]))
24+ if pre_module is not None :
25+ model ._modules [kw ].set_pre_module (pre_module )
26+ pre_module = model ._modules [kw ]
2327 else :
2428 model ._modules [kw ] = model_wrapper_dispatch (model ._modules [kw ])
2529
You can’t perform that action at this time.
0 commit comments