Commit 46a404b

djkajka committed
1 parent 36b2367 commit 46a404b

4 files changed (+9, -32 lines)

src/diffusers/pipelines/kolors/text_encoder.py

Lines changed: 6 additions & 29 deletions
@@ -104,13 +104,6 @@ def forward(self, hidden_states: torch.Tensor):
         return (self.weight * hidden_states).to(input_dtype)
 
 
-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
         )
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ def __init__(self, config: ChatGLMConfig, device=None):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
        )
 
         def swiglu(x):
@@ -459,9 +449,7 @@ def swiglu(x):
         self.activation_func = swiglu
 
         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
 
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ def build_layer(layer_number):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         self.gradient_checkpointing = False
 
@@ -679,9 +661,7 @@ def __init__(self, config: ChatGLMConfig, device=None):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection
 
     def forward(self, input_ids):
@@ -784,16 +764,13 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
-            dtype=config.torch_dtype,
            **init_kwargs,
        )
        self.pre_seq_len = config.pre_seq_len
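
Note: with _config_to_kwargs and the hard-coded dtype=config.torch_dtype arguments gone, the submodules are built in the default dtype and precision is chosen when the model is loaded or cast, as with any other PyTorch module. A minimal sketch of that usage, assuming the ChatGLMModel import path shown and reusing the tiny checkpoint from the test changes below:

import torch

from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# Precision is requested at load time (the same from_pretrained pattern the updated tests use).
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
)

# Or cast the whole module afterwards, since no layer pins config.torch_dtype anymore.
text_encoder_fp16 = text_encoder.to(torch.float16)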

tests/pipelines/kolors/test_kolors.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
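
All three Kolors test suites (text-to-image here, img2img and PAG below) move the dummy text encoder from bfloat16 to float32. A quick check, illustrative only and not part of this commit, that the tiny checkpoint really loads in fp32 (same assumed import path as the sketch above):

import torch

from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# Load the tiny test checkpoint in fp32, exactly as the updated tests do.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
)

# Every parameter should now be float32.
assert all(p.dtype == torch.float32 for p in text_encoder.parameters())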

tests/pipelines/kolors/test_kolors_img2img.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

tests/pipelines/pag/test_pag_kolors.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
