@@ -31,18 +31,22 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
 
     def norm(self, input, weight):
         input_dtype = input.dtype
+        input_shape = input.shape
+        input = input.view(-1, self.tp_padding_head_num * self.head_dim_)
         input = input.to(torch.float32)
         variance = input.pow(2).mean(-1, keepdim=True)
         input = input * torch.rsqrt(variance + self.eps_)
-        return weight * input.to(input_dtype)
+        out = weight * input.to(input_dtype)
+        out = out.reshape(input_shape)
+        return out
 
     def tp_norm(self, input, weight):
         input_shape = input.shape
         input = input.view(-1, self.tp_padding_head_num * self.head_dim_)
         input_dtype = input.dtype
         input = input.to(torch.float32)
         tp_variance = input.pow(2).sum(-1, keepdim=True)
-        dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False)
+        # dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False)
         variance = tp_variance / self.embed_dim_
         input = input * torch.rsqrt(variance + self.eps_)
         out = weight * input.to(input_dtype)
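For reference, a minimal standalone sketch of what the patched `norm` computes: an RMSNorm applied locally over the flattened per-rank head slice, in fp32, with the result cast back and reshaped to the caller's layout. `head_num`, `head_dim`, and `eps` are stand-ins for the class attributes `tp_padding_head_num`, `head_dim_`, and `eps_`, and the weight is assumed to broadcast over the flattened width; unlike `tp_norm`, and consistent with the `all_reduce` being commented out above, no cross-rank reduction takes place.

import torch

def local_rms_norm(x: torch.Tensor, weight: torch.Tensor, head_num: int, head_dim: int, eps: float = 1e-6) -> torch.Tensor:
    # Flatten to (tokens, head_num * head_dim) and compute the RMS statistic in fp32,
    # mirroring the view/mean/rsqrt sequence in the patched norm().
    orig_shape, orig_dtype = x.shape, x.dtype
    x = x.view(-1, head_num * head_dim).to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Apply the learned scale, cast back, and restore the original layout.
    return (weight * x.to(orig_dtype)).reshape(orig_shape)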
@@ -75,8 +79,8 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         )
 
     def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
-        k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        q_norm = self.norm(q, layer_weight.q_norm_weight_.weight)
+        k_norm = self.norm(k, layer_weight.k_norm_weight_.weight)
         return q_norm, k_norm
 
     def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
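With `_qk_norm` now calling `self.norm`, q and k are normalized over the local (padded) head slice only, matching the disabled cross-rank reduction in `tp_norm`. A hedged usage sketch, reusing `local_rms_norm` from above with purely illustrative shapes and a stand-in for the real weight tensor:

batch_size, seq_len, head_num, head_dim = 2, 16, 4, 64  # hypothetical values
q = torch.randn(batch_size, seq_len, head_num, head_dim, dtype=torch.float16)
q_weight = torch.ones(head_num * head_dim, dtype=torch.float16)  # stand-in for layer_weight.q_norm_weight_.weight
q_norm = local_rms_norm(q, q_weight, head_num, head_dim)
assert q_norm.shape == q.shape and q_norm.dtype == q.dtype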
@@ -85,6 +89,9 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
         qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
         q, k, v = qkv.unbind(2)
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
         return q, k, v
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
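The added `.contiguous()` calls are needed because `Tensor.unbind` returns strided views: slicing the q/k/v planes out of the packed (batch_size, seq_len, 3, head_num, head_dim) buffer leaves gaps in memory, and downstream attention kernels generally expect densely packed inputs. A small self-contained check (shapes are illustrative only):

import torch

qkv = torch.randn(2, 16, 3, 4, 64)  # (batch_size, seq_len, 3, head_num, head_dim)
q, k, v = qkv.unbind(2)             # views into qkv, not copies
print(q.is_contiguous())            # False: the stride still skips over the k and v planes
q = q.contiguous()                  # materialize a densely packed copy
print(q.is_contiguous())            # True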