@@ -680,29 +680,30 @@ def __repr__(self):
 
 class OpTransformerBlockList(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_state, *args):
+    def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args):
         tensors = []
         others = []
-        for arg in args:
+        for arg in args[num_hidden:]:
             if torch.is_tensor(arg):
                 tensors.append(arg)
                 others.append(None)
             else:
                 tensors.append(None)
                 others.append(arg)
+        hidden_states = args[:num_hidden]
 
         ctx.nontensor_inputs = others
         ctx.self = self
         ctx.save_list = copy.deepcopy(save_list)
         ctx.num_save_needed = save_list[-1][1]+1
-        ctx.layers_dict = [{} for _ in range(len(self))]
+        ctx.layers_dict = [{} for _ in range(len(self))]
         layer_inputs = []
         layer_inspector = []
         cuda_rng_state = []
         for i in range(len(self)):
             with torch.no_grad():
                 if save_list[i][0] == i:
-                    layer_inputs.append(hidden_state.detach())
+                    layer_inputs += [hidden_state.detach() for hidden_state in hidden_states]
                 cuda_rng_state.append( torch.cuda.get_rng_state() )
                 if config['zero_level']==2:
                     flag = 1
@@ -713,29 +714,38 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
                 block_ctx.enter()
                 # call inner module directly
                 with ScopedTensorInspectorContext() as inspector:
-                    hidden_state = self._modules[str(i)]._module._call_impl(hidden_state, *args)
+                    hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:])
+                    if not isinstance(hidden_states, tuple):
+                        hidden_states = (hidden_states,)
                 block_ctx.exit()
                 for it in inspector.hidden_states:
                     debug.append("_inspect_hidden_states", it)
                 layer_inspector.append(inspector.hidden_states)
 
         ctx.layer_inspector = layer_inspector
         ctx.cuda_rng_state = cuda_rng_state
+        ctx.num_hidden = num_hidden
 
         ctx.save_for_backward(*layer_inputs, *tensors)
 
         if self.return_hidden_states:
             middle_hiddens = layer_inputs
             for mid in middle_hiddens:
                 mid.requires_grad_()
-            middle_hiddens = torch.stack(middle_hiddens, dim=0)
+            middle_hiddens = [
+                torch.stack(middle_hiddens[i::num_hidden], dim=0)
+                for i in range(num_hidden)
+            ]
         else:
-            middle_hiddens = None
-        return tuple([hidden_state, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
+            middle_hiddens = [None] * num_hidden
+        return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
 
 
     @staticmethod
-    def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : List[torch.Tensor], *grad_inspectors):
+    def backward(ctx, *grads):
+        grad_hidden_states = grads[:ctx.num_hidden]
+        grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden]
+        grad_inspectors = grads[2*ctx.num_hidden:]
         def exit_prev(prev_ctx, prev_grad):
             if prev_ctx is not None:
                 if prev_grad:
@@ -755,8 +765,8 @@ def exit_prev(prev_ctx, prev_grad):
         all_inputs = []
         input_requires_grad = []
 
-        layer_inputs = ctx.saved_tensors[:ctx.num_save_needed]
-        save_args = ctx.saved_tensors[ctx.num_save_needed:]
+        layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden]
+        save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:]
         for tensor, other in zip(save_args, ctx.nontensor_inputs):
             if tensor is None:
                 all_inputs.append(other)
@@ -786,14 +796,23 @@ def exit_prev(prev_ctx, prev_grad):
                                 block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag)
                                 block_ctx.enter()
                                 exit_prev(prev_ctx, prev_grad)
-                                output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs)
+                                outputs = ctx.self._modules[str(j)]._module._call_impl(
+                                    layer_inputs[ctx.save_list[j][1]*ctx.num_hidden : ctx.save_list[j][1]*ctx.num_hidden + ctx.num_hidden],
+                                    *all_inputs
+                                )
+                                if not isinstance(outputs, tuple):
+                                    outputs = (outputs,)
                                 prev_ctx = block_ctx
                                 prev_grad = False
-                                layer_inputs[ctx.save_list[j+1][1]].copy_(output)
+                                for k, output in enumerate(outputs):
+                                    layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output)
                                 ctx.save_list[j+1][0] = j+1
 
                     torch.cuda.set_rng_state(ctx.cuda_rng_state[i])
-                    ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_()
+                    ipts = [
+                        layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_()
+                        for k in range(ctx.num_hidden)
+                    ]
                     if config['zero_level'] == 2:
                         flag = 2
                     else:
@@ -805,7 +824,9 @@ def exit_prev(prev_ctx, prev_grad):
                     prev_grad = True
 
                     with ScopedTensorInspectorContext() as inspector:
-                        output = ctx.self._modules[str(i)]._module._call_impl(ipt, *all_inputs)
+                        outputs = ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs)
+                        if not isinstance(outputs, tuple):
+                            outputs = (outputs,)
 
                     assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed"
                     for j, it in enumerate(inspector.hidden_states):
@@ -818,18 +839,20 @@ def exit_prev(prev_ctx, prev_grad):
                         ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
                     if len(inspector.hidden_states) > 0:
                         torch.autograd.backward(
-                            [output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
-                            (grad_hidden_state,) + grad_inspectors[-len(inspector.hidden_states):],
+                            list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
+                            grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):],
                         )
                         grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)]
                     else:
                         torch.autograd.backward(
-                            [output],
-                            (grad_hidden_state,),
+                            outputs,
+                            grad_hidden_states,
                         )
-                    grad_hidden_state = ipt.grad
-                    if grad_middle is not None:
-                        grad_hidden_state = grad_hidden_state + grad_middle[i]
+                    grad_hidden_states = [ipt.grad for ipt in ipts]
+                    for k in range(ctx.num_hidden):
+                        if grad_middles[k] is not None:
+                            grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i]
+                    grad_hidden_states = tuple(grad_hidden_states)
 
             exit_prev(prev_ctx, prev_grad)
 
@@ -839,7 +862,7 @@ def exit_prev(prev_ctx, prev_grad):
                 grads.append(inp.grad)
             else:
                 grads.append(None)
-        return (None, None, None, grad_hidden_state) + tuple(grads)
+        return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads)
 
 class TransformerBlockList(torch.nn.Module):
     r"""
@@ -862,7 +885,7 @@ class TransformerBlockList(torch.nn.Module):
     """
     _modules: Dict[str, CheckpointBlock]
 
-    def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
+    def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None:
         super().__init__()
 
         self._modules = {}
@@ -872,6 +895,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
             self._modules[str(i)] = module
             self.add_module(str(i), module)
 
+        self.num_hidden = num_hidden
+
         if sqrt:
             length = len(self)
             num_save_needed = 0
@@ -901,12 +926,11 @@ def __iter__(self) -> Iterator[CheckpointBlock]:
     def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
         return self._modules[str(index)]
 
-    def forward(self, hidden_state, *args, return_hidden_states = False):
+    def forward(self, *args, return_hidden_states = False):
         self.return_hidden_states = return_hidden_states
         placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
-        outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
-        last_hidden, middle_hiddens = outputs[:2]
+        outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args)
         if return_hidden_states:
-            return last_hidden, middle_hiddens
+            return tuple(outputs[:2*self.num_hidden])
         else:
-            return last_hidden
+            return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0]
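
Below, for context, is a rough usage sketch of the interface after this change, assuming this file is BMTrain's bmtrain/block_layer.py (so TransformerBlockList, CheckpointBlock, DistributedParameter, and init_distributed are reachable via `import bmtrain as bmt`). DummyBlock, the dimensions, and the extra `bias` argument are illustrative placeholders, not part of the commit: the first num_hidden positional arguments to forward are treated as hidden states, each block must return that many tensors, and any remaining arguments are passed unchanged to every block.

import torch
import bmtrain as bmt   # assumed package for the classes patched above

class DummyBlock(torch.nn.Module):
    # Toy block with two hidden streams, matching num_hidden=2 below (placeholder, not from the diff).
    def __init__(self, dim):
        super().__init__()
        self.w_a = bmt.DistributedParameter(torch.randn(dim, dim) * 0.02)
        self.w_b = bmt.DistributedParameter(torch.randn(dim, dim) * 0.02)

    def forward(self, h_a, h_b, bias):
        # A block must return exactly num_hidden hidden states.
        return torch.matmul(h_a, self.w_a) + bias, torch.matmul(h_b, self.w_b)

bmt.init_distributed()          # must be called before CheckpointBlock is used
dim = 16
blocks = bmt.TransformerBlockList(
    [bmt.CheckpointBlock(DummyBlock(dim)) for _ in range(4)],
    num_hidden=2,
)

h_a = torch.randn(2, 8, dim, device="cuda")
h_b = torch.randn(2, 8, dim, device="cuda")
bias = torch.zeros(2, 8, dim, device="cuda")

# The first num_hidden positional args are hidden states; the rest reach every block.
out_a, out_b = blocks(h_a, h_b, bias)

# With return_hidden_states=True the call additionally returns, per hidden stream,
# the per-layer inputs stacked along a new leading dimension of size num_layers.
out_a, out_b, mid_a, mid_b = blocks(h_a, h_b, bias, return_hidden_states=True)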