Commit a00a584: Make Block accept kwargs

Parent: 3d35a26

2 files changed: 11 additions, 3 deletions

bmtrain/block_layer.py (5 additions, 1 deletion)

```diff
@@ -311,9 +311,13 @@ def post_hook(self, out):
         post_out = tuple(post_out)
         return post_out
 
-    def forward(self, *args):
+    def forward(self, *args, **kwargs):
+        signature = inspect.signature(self._module.forward)
+        bound_args = signature.bind(*args, **kwargs)
+        args = bound_args.args
         arg_list = self.pre_hook(*args)
 
+
         if self.all_input_no_grad and not self.all_param_no_grad:
             placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
             return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list)
```
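The new `forward` uses `inspect.signature(self._module.forward).bind(...)` to fold keyword arguments into their positional slots before calling `pre_hook`, which only iterates over `*args` (the hunk assumes `inspect` is already imported in block_layer.py). A minimal, runnable sketch of that normalization, using a hypothetical layer signature:

```python
import inspect

# Hypothetical stand-in for a wrapped module's forward; not BMTrain's API.
def forward(hidden, attention_mask=None, use_cache=False):
    return hidden

signature = inspect.signature(forward)

# bind() maps keyword arguments onto their positional slots, so code that
# only understands *args (like pre_hook above) still receives them.
bound_args = signature.bind("H", attention_mask="M", use_cache=True)
print(bound_args.args)    # ('H', 'M', True)
print(bound_args.kwargs)  # {} -- everything was normalized to positional

# Caveat: bind() does not apply defaults, so an argument that was never
# passed simply does not appear in .args.
print(signature.bind("H").args)  # ('H',)
```

One limitation of this approach: keyword-only parameters cannot be converted to positional form, so they would land in `bound_args.kwargs` and be silently dropped by `args = bound_args.args`; the change assumes the wrapped `forward` takes only positional-or-keyword parameters.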

bmtrain/wrapper.py (6 additions, 2 deletions)

```diff
@@ -15,11 +15,15 @@ def make_distributed(model: torch.nn.Module):
     for kw in list(model._buffers.keys()):
         if model._buffers[kw] is not None:
             model._buffers[kw] = model._buffers[kw].cuda()
-
+    is_module_list = isinstance(model, torch.nn.ModuleList)
+    pre_module = None
     for kw in list(model._modules.keys()):
-        if isinstance(model, torch.nn.ModuleList):
+        if is_module_list:
             if not isinstance(model._modules[kw], Block):
                 model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw]))
+            if pre_module is not None:
+                model._modules[kw].set_pre_module(pre_module)
+            pre_module = model._modules[kw]
         else:
             model._modules[kw] = model_wrapper_dispatch(model._modules[kw])
 
```
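The wrapper change hoists the `isinstance(model, torch.nn.ModuleList)` test out of the loop and threads each newly wrapped `Block` to its predecessor via `set_pre_module`. A runnable sketch of that chaining, with a stub `Block` (the stub is an assumption for illustration; BMTrain's real `Block` and `set_pre_module` do considerably more):

```python
import torch

class Block(torch.nn.Module):
    """Stub standing in for bmtrain's Block wrapper."""
    def __init__(self, inner):
        super().__init__()
        self._module = inner
        self._pre_module = None

    def set_pre_module(self, pre_module):
        # Remember the Block wrapped immediately before this one, so
        # adjacent entries of a ModuleList can find each other.
        self._pre_module = pre_module

    def forward(self, x):
        return self._module(x)

layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(3)])

# Same loop shape as make_distributed above: wrap each child, then link
# it to the previously wrapped block.
pre_module = None
for kw in list(layers._modules.keys()):
    if not isinstance(layers._modules[kw], Block):
        layers._modules[kw] = Block(layers._modules[kw])
    if pre_module is not None:
        layers._modules[kw].set_pre_module(pre_module)
    pre_module = layers._modules[kw]

assert layers[1]._pre_module is layers[0]
assert layers[2]._pre_module is layers[1]
```

Because `pre_module` starts as `None`, the first block in the list is never linked backwards, matching the diff above.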
2529
