@@ -104,13 +104,6 @@ def forward(self, hidden_states: torch.Tensor):
         return (self.weight * hidden_states).to(input_dtype)
 
 
-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
         )
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ def __init__(self, config: ChatGLMConfig, device=None):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         def swiglu(x):
@@ -459,9 +449,7 @@ def swiglu(x):
         self.activation_func = swiglu
 
         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
 
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ def build_layer(layer_number):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         self.gradient_checkpointing = False
 
@@ -679,9 +661,7 @@ def __init__(self, config: ChatGLMConfig, device=None):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection
 
     def forward(self, input_ids):
@@ -784,16 +764,13 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
             nn.Linear,
             config.hidden_size,
             config.padded_vocab_size,
             bias=False,
-            dtype=config.torch_dtype,
             **init_kwargs,
         )
         self.pre_seq_len = config.pre_seq_len
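Note: with _config_to_kwargs removed and the explicit dtype=config.torch_dtype arguments dropped from the submodule constructors, the parameter dtype is no longer fixed at module construction time. A minimal usage sketch of how the dtype would then typically be set, assuming the standard transformers from_pretrained loading path; the checkpoint name below is only an illustrative placeholder:

import torch
from transformers import AutoModel

# Placeholder checkpoint name, for illustration only.
model = AutoModel.from_pretrained(
    "THUDM/chatglm3-6b",
    torch_dtype=torch.bfloat16,  # dtype applied at load time rather than per submodule
)
# Alternatively, an already-constructed model can be cast afterwards:
# model = model.to(torch.bfloat16)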