
Commit 9cc9755

Merge pull request #92 from Achazwl/feat-multi-return
support multiple input-output in transformerblocklist
2 parents d531727 + 7084623 commit 9cc9755
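In short, TransformerBlockList can now carry more than one hidden-state tensor through its checkpointed blocks: the new num_hidden argument tells the list how many of the leading positional inputs are hidden states, and forward returns the same number of tensors. A minimal usage sketch, adapted from tests/test_multi_return.py added in this commit (the toy module, layer count, and tensor size are illustrative, not part of the library):

import torch
import bmtrain as bmt
from bmtrain.block_layer import CheckpointBlock, TransformerBlockList

class MultiInputReturn(torch.nn.Module):
    # three hidden streams (a, b, c) plus two extra per-call inputs (d, e)
    def forward(self, a, b, c, d, e):
        return a * 2, b + d, c * 4 + e * 5

bmt.init_distributed()  # distributed/ZeRO setup must be initialized before building BMTrain blocks

blocks = TransformerBlockList(
    [CheckpointBlock(MultiInputReturn()) for _ in range(4)],
    num_hidden=3,  # number of hidden-state tensors threaded through the list
)

a, b, c, d, e = (torch.randn(4096).cuda() for _ in range(5))
out_a, out_b, out_c = blocks(a, b, c, d, e)  # num_hidden outputs come back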

File tree (3 files changed, +180 −29 lines):
  bmtrain/block_layer.py
  tests/test_all.py
  tests/test_multi_return.py

bmtrain/block_layer.py

Lines changed: 53 additions & 29 deletions
@@ -680,29 +680,30 @@ def __repr__(self):
 
 class OpTransformerBlockList(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_state, *args):
+    def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args):
         tensors = []
         others = []
-        for arg in args:
+        for arg in args[num_hidden:]:
             if torch.is_tensor(arg):
                 tensors.append(arg)
                 others.append(None)
             else:
                 tensors.append(None)
                 others.append(arg)
+        hidden_states = args[:num_hidden]
 
         ctx.nontensor_inputs = others
         ctx.self = self
         ctx.save_list = copy.deepcopy(save_list)
         ctx.num_save_needed = save_list[-1][1]+1
-        ctx.layers_dict=[{} for _ in range(len(self))]
+        ctx.layers_dict = [{} for _ in range(len(self))]
         layer_inputs = []
         layer_inspector = []
         cuda_rng_state = []
         for i in range(len(self)):
             with torch.no_grad():
                 if save_list[i][0] == i:
-                    layer_inputs.append(hidden_state.detach())
+                    layer_inputs += [hidden_state.detach() for hidden_state in hidden_states]
                 cuda_rng_state.append( torch.cuda.get_rng_state() )
                 if config['zero_level']==2:
                     flag = 1
@@ -713,29 +714,38 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
                 block_ctx.enter()
                 # call inner module directly
                 with ScopedTensorInspectorContext() as inspector:
-                    hidden_state = self._modules[str(i)]._module._call_impl(hidden_state, *args)
+                    hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:])
+                if not isinstance(hidden_states, tuple):
+                    hidden_states = (hidden_states,)
                 block_ctx.exit()
                 for it in inspector.hidden_states:
                     debug.append("_inspect_hidden_states", it)
                 layer_inspector.append(inspector.hidden_states)
 
         ctx.layer_inspector = layer_inspector
         ctx.cuda_rng_state = cuda_rng_state
+        ctx.num_hidden = num_hidden
 
         ctx.save_for_backward(*layer_inputs, *tensors)
 
         if self.return_hidden_states:
             middle_hiddens = layer_inputs
             for mid in middle_hiddens:
                 mid.requires_grad_()
-            middle_hiddens = torch.stack(middle_hiddens, dim=0)
+            middle_hiddens = [
+                torch.stack(middle_hiddens[i::num_hidden], dim=0)
+                for i in range(num_hidden)
+            ]
         else:
-            middle_hiddens = None
-        return tuple([hidden_state, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
+            middle_hiddens = [None] * num_hidden
+        return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
 
 
     @staticmethod
-    def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle: List[torch.Tensor], *grad_inspectors):
+    def backward(ctx, *grads):
+        grad_hidden_states = grads[:ctx.num_hidden]
+        grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden]
+        grad_inspectors = grads[2*ctx.num_hidden:]
         def exit_prev(prev_ctx, prev_grad):
             if prev_ctx is not None:
                 if prev_grad:
@@ -755,8 +765,8 @@ def exit_prev(prev_ctx, prev_grad):
         all_inputs = []
         input_requires_grad = []
 
-        layer_inputs = ctx.saved_tensors[:ctx.num_save_needed]
-        save_args = ctx.saved_tensors[ctx.num_save_needed:]
+        layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden]
+        save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:]
         for tensor, other in zip(save_args, ctx.nontensor_inputs):
             if tensor is None:
                 all_inputs.append(other)
@@ -786,14 +796,23 @@ def exit_prev(prev_ctx, prev_grad):
                                 block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag)
                                 block_ctx.enter()
                                 exit_prev(prev_ctx, prev_grad)
-                                output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs)
+                                outputs = ctx.self._modules[str(j)]._module._call_impl(
+                                    layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden],
+                                    *all_inputs
+                                )
+                                if not isinstance(outputs, tuple):
+                                    outputs = (outputs,)
                                 prev_ctx = block_ctx
                                 prev_grad = False
-                                layer_inputs[ctx.save_list[j+1][1]].copy_(output)
+                                for k, output in enumerate(outputs):
+                                    layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output)
                                 ctx.save_list[j+1][0] = j+1
 
                     torch.cuda.set_rng_state(ctx.cuda_rng_state[i])
-                    ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_()
+                    ipts = [
+                        layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_()
+                        for k in range(ctx.num_hidden)
+                    ]
                     if config['zero_level'] == 2:
                         flag = 2
                     else:
@@ -805,7 +824,9 @@ def exit_prev(prev_ctx, prev_grad):
                     prev_grad = True
 
                     with ScopedTensorInspectorContext() as inspector:
-                        output = ctx.self._modules[str(i)]._module._call_impl(ipt, *all_inputs)
+                        outputs = ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs)
+                    if not isinstance(outputs, tuple):
+                        outputs = (outputs,)
 
                     assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed"
                     for j, it in enumerate(inspector.hidden_states):
@@ -818,18 +839,20 @@ def exit_prev(prev_ctx, prev_grad):
                         ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
                     if len(inspector.hidden_states) > 0:
                         torch.autograd.backward(
-                            [output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
-                            (grad_hidden_state,) + grad_inspectors[-len(inspector.hidden_states):],
+                            list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
+                            grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):],
                         )
                         grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)]
                     else:
                         torch.autograd.backward(
-                            [output],
-                            (grad_hidden_state,),
+                            outputs,
+                            grad_hidden_states,
                         )
-                    grad_hidden_state = ipt.grad
-                    if grad_middle is not None:
-                        grad_hidden_state = grad_hidden_state + grad_middle[i]
+                    grad_hidden_states = [ipt.grad for ipt in ipts]
+                    for k in range(ctx.num_hidden):
+                        if grad_middles[k] is not None:
+                            grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i]
+                    grad_hidden_states = tuple(grad_hidden_states)
 
                 exit_prev(prev_ctx, prev_grad)
 
@@ -839,7 +862,7 @@ def exit_prev(prev_ctx, prev_grad):
                 grads.append(inp.grad)
             else:
                 grads.append(None)
-        return (None, None, None, grad_hidden_state) + tuple(grads)
+        return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads)
 
 class TransformerBlockList(torch.nn.Module):
     r"""
@@ -862,7 +885,7 @@ class TransformerBlockList(torch.nn.Module):
     """
     _modules: Dict[str, CheckpointBlock]
 
-    def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
+    def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None:
         super().__init__()
 
         self._modules = {}
@@ -872,6 +895,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
             self._modules[str(i)] = module
             self.add_module(str(i), module)
 
+        self.num_hidden = num_hidden
+
         if sqrt:
             length = len(self)
             num_save_needed = 0
@@ -901,12 +926,11 @@ def __iter__(self) -> Iterator[CheckpointBlock]:
     def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
         return self._modules[str(index)]
 
-    def forward(self, hidden_state, *args, return_hidden_states = False):
+    def forward(self, *args, return_hidden_states = False):
         self.return_hidden_states = return_hidden_states
         placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
-        outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
-        last_hidden, middle_hiddens = outputs[:2]
+        outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args)
        if return_hidden_states:
-            return last_hidden, middle_hiddens
+            return tuple(outputs[:2*self.num_hidden])
         else:
-            return last_hidden
+            return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0]
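For reference, the flat tuple returned by OpTransformerBlockList.apply is now ordered as num_hidden last hidden states, then num_hidden middle-hidden entries, then any inspector tensors, and TransformerBlockList.forward slices it accordingly. A hedged sketch of unpacking the result, continuing the illustrative blocks / num_hidden=3 example above (with the default sqrt=False, so every layer's inputs are saved):

# default: just the num_hidden final hidden states
out_a, out_b, out_c = blocks(a, b, c, d, e)

# return_hidden_states=True: the first num_hidden entries are the final hidden
# states, the next num_hidden are the per-layer inputs stacked along dim 0
outs = blocks(a, b, c, d, e, return_hidden_states=True)
last_hiddens = outs[:3]      # one tensor per hidden stream
middle_hiddens = outs[3:6]   # each roughly (num_layers, *hidden_shape) in this setup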

tests/test_all.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     ("dropout", 1),
     ("loss_func", 1),
 
+    ("multi_return", 2),
     ("middle_hidden", 4),
     ("other_hidden", 4),
     ("inspector_hidden", 2),

tests/test_multi_return.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+from utils import *
+
+import bmtrain as bmt
+import torch
+import random
+from bmtrain import config
+from bmtrain.block_layer import CheckpointBlock, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+import torch.nn.functional as F
+
+class MultiInputReturn(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c, d, e):
+        return a*2, b+d, c*4+e*5
+
+class Model_ZERO(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = TransformerBlockList([
+            CheckpointBlock(m)
+            for m in ms
+        ], num_hidden=3)
+
+    def forward(self, x):
+        y = self.ms(*x)
+        return y
+
+class Model_PIPE(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = PipelineTransformerBlockList([
+            CheckpointBlock(m)
+            for m in ms
+        ], num_hidden=3)
+
+    def forward(self, x):
+        y = self.ms(*x)
+        return y
+
+class Model_BLOCK(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = torch.nn.ModuleList([
+            CheckpointBlock(m)
+            for m in ms
+        ])
+
+    def forward(self, x):
+        y = x[:3]
+        other = x[3:]
+        for m in self.ms:
+            y = m(*y, *other)
+        return y
+
+class Model_NORMAL(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = torch.nn.ModuleList(ms)
+
+    def forward(self, x):
+        y = x[:3]
+        other = x[3:]
+        for m in self.ms:
+            y = m(*y, *other)
+        return y
+
+def manual_seed(seed=33):
+    torch.manual_seed(seed)
+    random.seed(seed)
+    try:
+        import numpy as np
+        np.random.seed(seed)
+    except ModuleNotFoundError:
+        pass
+
+def run(name, cls, num_layer=4, dim=4096):
+    manual_seed()
+
+    ms = [MultiInputReturn() for i in range(num_layer)]
+
+    inps = (
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+    )
+    last_weights = (
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+    )
+
+    for inp in inps:
+        inp.requires_grad_(True)
+    m = cls(ms)
+
+    ret = ""
+    logits = m(inps)
+    loss = (logits[0]*last_weights[0] + logits[1]*last_weights[1] + logits[2]*last_weights[2]).sum()
+    loss.backward()
+    return list(logits) + [
+        inp.grad
+        for inp in inps
+    ]
+
+def test_main():
+    ret = {}
+    ret["normal"] = run("normal", Model_NORMAL)
+    ret["block"] = run("block", Model_BLOCK)
+    ret["zero"] = run("zero", Model_ZERO)
+    # ret["pipe"] = run("pipe", Model_PIPE) # TODO pipeline not support multiple input-output yet
+    for k, r in ret.items():
+        bmt.print_rank(f"============={k}============")
+        bmt.print_rank(r)
+    for r in ret.values():
+        for r2 in ret.values():
+            for i in range(len(r)):
+                assert_lt((r[i]-r2[i]).abs().max(), 1e-5)
+
+if __name__ == "__main__":
+    bmt.init_distributed(pipe_size=2)
+
+    test_main()
