4 changes: 2 additions & 2 deletions src/diffusers/loaders/single_file.py
@@ -360,12 +360,12 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

         is_legacy_loading = False

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
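For context, the practical effect of the new `None` default in `from_single_file`: omitting `torch_dtype` no longer injects an implicit `float32` default, while passing a dtype behaves as before. A minimal sketch, assuming a hypothetical local single-file checkpoint path:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Hypothetical checkpoint path, used purely for illustration.
ckpt = "./checkpoints/sdxl_base.safetensors"

# No torch_dtype: the loader no longer forces torch.float32 as a default.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt)

# Passing a dtype explicitly is unchanged.
pipe_fp16 = StableDiffusionXLPipeline.from_single_file(ckpt, torch_dtype=torch.float16)
```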
4 changes: 2 additions & 2 deletions src/diffusers/loaders/single_file_model.py
@@ -255,12 +255,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
4 changes: 2 additions & 2 deletions src/diffusers/models/modeling_utils.py
@@ -870,7 +870,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -883,7 +883,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
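The guard in `ModelMixin.from_pretrained` (and in the two single-file loaders above) now treats `None` as "no dtype requested" and only falls back to `float32` for values that are not a `torch.dtype`. A small standalone mirror of that logic, not a diffusers API, just to make the three cases explicit:

```python
import torch

def resolve_torch_dtype(torch_dtype):
    # Illustrative helper mirroring the guard introduced in this PR.
    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
        print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
        return torch.float32
    return torch_dtype

assert resolve_torch_dtype(None) is None                    # no cast requested
assert resolve_torch_dtype(torch.float16) is torch.float16  # valid dtype passes through
assert resolve_torch_dtype("fp16") is torch.float32         # invalid value falls back with a warning
```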
35 changes: 6 additions & 29 deletions src/diffusers/pipelines/kolors/text_encoder.py
@@ -104,13 +104,6 @@ def forward(self, hidden_states: torch.Tensor):
         return (self.weight * hidden_states).to(input_dtype)


-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )

         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
         )

     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ def __init__(self, config: ChatGLMConfig, device=None):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
         )

         def swiglu(x):
@@ -459,9 +449,7 @@ def swiglu(x):
         self.activation_func = swiglu

         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)

     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):

         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout

         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ def build_layer(layer_number):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

         self.gradient_checkpointing = False

@@ -679,9 +661,7 @@ def __init__(self, config: ChatGLMConfig, device=None):

         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection

     def forward(self, input_ids):
@@ -784,16 +764,13 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )

-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
             nn.Linear,
             config.hidden_size,
             config.padded_vocab_size,
             bias=False,
-            dtype=config.torch_dtype,
             **init_kwargs,
         )
         self.pre_seq_len = config.pre_seq_len
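With `_config_to_kwargs` and the explicit `dtype=config.torch_dtype` arguments removed, the ChatGLM layers are now created in the framework default dtype, and the final dtype comes from the usual loading machinery (`from_pretrained(..., torch_dtype=...)` or `.to(dtype)`) rather than from the config. A short sketch using the tiny test checkpoint referenced by the Kolors tests below:

```python
import torch
from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# Dtype is chosen at load/cast time, not baked into the layers from config.torch_dtype.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
)
print(next(text_encoder.parameters()).dtype)  # torch.float32

# Casting afterwards goes through the standard nn.Module machinery.
text_encoder = text_encoder.to(torch.float16)
print(next(text_encoder.parameters()).dtype)  # torch.float16
```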
8 changes: 4 additions & 4 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -686,7 +686,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -703,7 +703,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -1456,8 +1456,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

         if load_components_from_hub and not trust_remote_code:
             raise ValueError(
-                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
-                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
                 f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
             )

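On the pipeline side the behaviour matches the model loaders: leaving `torch_dtype` unset no longer implies a blanket `float32` default, and an explicit dtype is still honoured (the last hunk is only a cosmetic `k, v` spacing fix in the error message). A hedged usage sketch; the repo id is just a stand-in for any pipeline repository:

```python
import torch
from diffusers import DiffusionPipeline

repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # stand-in example repo

# No torch_dtype: components are loaded without the old implicit float32 default.
pipe = DiffusionPipeline.from_pretrained(repo_id)

# Explicit dtype request is unchanged.
pipe_fp16 = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
```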
2 changes: 1 addition & 1 deletion tests/pipelines/kolors/test_kolors.py
@@ -90,7 +90,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

2 changes: 1 addition & 1 deletion tests/pipelines/kolors/test_kolors_img2img.py
@@ -94,7 +94,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

2 changes: 1 addition & 1 deletion tests/pipelines/pag/test_pag_kolors.py
@@ -99,7 +99,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

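All three Kolors test suites now pin the dummy ChatGLM text encoder to `torch.float32` instead of `torch.bfloat16`, presumably so the dummy components share one CPU-friendly dtype now that the encoder no longer derives a dtype from its config. A quick way to sanity-check the resulting dtype:

```python
import torch
from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# Same tiny checkpoint the tests use; all parameters should come back as float32.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
)
assert all(p.dtype == torch.float32 for p in text_encoder.parameters())
```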