@@ -49,23 +49,26 @@ def _to_16bit(inputs):
             else:
                 return inputs
 
-        self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))
+        self.specs.append(lambda inputss: tuple(tuple(_to_16bit(inputs)) for inputs in inputss))
 
         # Embedding layer
         self.specs.append(TiedLayerSpec('embed',
                                         EmbeddingPipe,
                                         args.hidden_size,
                                         args.padded_vocab_size,
                                         args.hidden_dropout,
+                                        forward_fn=lambda module, inputs, targets: (module(*inputs), module(*targets)),
                                         init_method=init_method,
                                         num_tokentypes=num_tokentypes,
                                         tied_weight_attr='word_embeddings_weight'))
 
         assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention masks"
+        # Drop everything besides tokens
+        self.specs.append(lambda inputs, targets: (inputs[0], targets[0]))
         if args.fp32_residual_connection:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous().float(), target_tokens.transpose(0, 1).contiguous().float()))
         else:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous(), target_tokens.transpose(0, 1).contiguous()))
 
         ### ----- Encoder -----
         for layer_idx in range(args.num_layers):
@@ -74,22 +77,21 @@ def _to_16bit(inputs):
                 f"block_{layer_idx}",
                 ParallelTransformerLayerPipe,
                 init_method=init_method,
-                # Inputs: (input_tokens, target_tokens,
-                forward_fn=lambda module, *inputs: ,
+                forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                 output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                    args.num_layers),
                 layer_type=LayerType.encoder,
                 layer_number=layer_idx,
                 self_attn_mask_type=AttnMaskType.causal,
-                tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                tied_weight_attrs=["self_attention", "mlp"]
             ))
 
         # Final layernorm after encoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
+                forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                 eps=args.layernorm_epsilon
             ))
 
@@ -100,19 +102,22 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
+                    forward_fn=lambda module, encoded_tokens, target_tokens: (encoded_tokens, module(target_tokens, encoder_output=encoded_tokens)),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                        args.num_layers),
                     layer_number=layer_idx,
                     layer_type=LayerType.decoder,
                     self_attn_mask_type=AttnMaskType.padding,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 )
             )
 
+        # Drop encoded tokens
+        self.specs.append(lambda encoded_tokens, target_tokens: target_tokens)
+
         # Final layernorm after decoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
                 eps=args.layernorm_epsilon
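
For context on how a specs list like this is consumed downstream: DeepSpeed's PipelineModule accepts a mix of LayerSpec/TiedLayerSpec entries and plain callables, and it invokes the bare lambdas directly on whatever the previous stage produced, which is what the token-reshuffling lambdas in this diff rely on. The snippet below is a minimal sketch of that mechanism only, not the model code from this branch; the layer sizes, environment variables, and single-process, single-GPU setup are illustrative assumptions, and the forward_fn keyword used above is taken as given from this branch rather than exercised here.

import os
import torch
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec

# Minimal single-process "distributed" setup so PipelineModule can build its
# topology; values are placeholders for a local experiment with one visible GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
deepspeed.init_distributed(dist_backend="gloo")

hidden = 64  # made-up size for illustration

specs = [
    # Plain callables may sit between LayerSpec stages; the pipeline engine
    # calls them as-is on the previous stage's output, mirroring the
    # (input_tokens, target_tokens) reshuffling lambdas in this PR.
    lambda x: x.float(),
    LayerSpec(nn.Linear, hidden, hidden),
    lambda x: torch.tanh(x),
    LayerSpec(nn.LayerNorm, hidden),
]

# With num_stages=1 every layer lands on this process; deepspeed.initialize()
# would normally wrap this module in a pipeline engine for training.
model = PipelineModule(layers=specs, num_stages=1)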