Skip to content

Commit 9140185

Browse files
authored
Merge pull request #207 from OpenBMB/dev
bmt.Block now can accept kwargs in forward function
2 parents d7bb04c + 2497721 commit 9140185

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

bmtrain/block_layer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,13 @@ def post_hook(self, out):
311311
post_out = tuple(post_out)
312312
return post_out
313313

314-
def forward(self, *args):
314+
def forward(self, *args, **kwargs):
315+
signature = inspect.signature(self._module.forward)
316+
bound_args = signature.bind(*args, **kwargs)
317+
args = bound_args.args
315318
arg_list = self.pre_hook(*args)
316319

320+
317321
if self.all_input_no_grad and not self.all_param_no_grad:
318322
placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
319323
return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list)

bmtrain/wrapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ def make_distributed(model: torch.nn.Module):
1515
for kw in list(model._buffers.keys()):
1616
if model._buffers[kw] is not None:
1717
model._buffers[kw] = model._buffers[kw].cuda()
18-
18+
is_module_list = isinstance(model, torch.nn.ModuleList)
19+
pre_module = None
1920
for kw in list(model._modules.keys()):
20-
if isinstance(model, torch.nn.ModuleList):
21+
if is_module_list:
2122
if not isinstance(model._modules[kw], Block):
2223
model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw]))
24+
if pre_module is not None:
25+
model._modules[kw].set_pre_module(pre_module)
26+
pre_module = model._modules[kw]
2327
else:
2428
model._modules[kw] = model_wrapper_dispatch(model._modules[kw])
2529

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def build_extension(self, ext):
9393
]
9494
setup(
9595
name='bmtrain',
96-
version='1.0.0',
96+
version='1.0.1',
9797
author="Guoyang Zeng",
9898
author_email="[email protected]",
9999
description="A toolkit for training big models",

0 commit comments

Comments
 (0)