Skip to content

Commit 6b2b6c4

Browse files
committed
Fix: HunyuanImage-2.1
1 parent 4edb6d4 commit 6b2b6c4

File tree

3 files changed

+0
-157
lines changed

3 files changed

+0
-157
lines changed

diffsynth/utils/state_dict_converters/hunyuan_dit_converter.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,3 @@
1-
# def HunyuanDiTStateDictConverter(state_dict):
2-
# new_dict = {}
3-
# for k, w in state_dict.items():
4-
# if k.startswith("double_blocks") and "attn_qkv.weight" in k:
5-
# hidden_size = w.shape[1]
6-
# k1 = k.replace("attn_qkv.weight", "attn_q.weight")
7-
# w1 = w[:hidden_size, :]
8-
# new_dict[k1] = w1
9-
# k2 = k.replace("attn_qkv.weight", "attn_k.weight")
10-
# w2 = w[hidden_size : 2 * hidden_size, :]
11-
# new_dict[k2] = w2
12-
# k3 = k.replace("attn_qkv.weight", "attn_v.weight")
13-
# w3 = w[-hidden_size:, :]
14-
# new_dict[k3] = w3
15-
# elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
16-
# hidden_size = w.shape[0] // 3
17-
# k1 = k.replace("attn_qkv.bias", "attn_q.bias")
18-
# w1 = w[:hidden_size]
19-
# new_dict[k1] = w1
20-
# k2 = k.replace("attn_qkv.bias", "attn_k.bias")
21-
# w2 = w[hidden_size : 2 * hidden_size]
22-
# new_dict[k2] = w2
23-
# k3 = k.replace("attn_qkv.bias", "attn_v.bias")
24-
# w3 = w[-hidden_size:]
25-
# new_dict[k3] = w3
26-
# elif k.startswith("single_blocks") and "linear1" in k:
27-
# hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
28-
# k1 = k.replace("linear1", "linear1_q")
29-
# w1 = w[:hidden_size]
30-
# new_dict[k1] = w1
31-
# k2 = k.replace("linear1", "linear1_k")
32-
# w2 = w[hidden_size : 2 * hidden_size]
33-
# new_dict[k2] = w2
34-
# k3 = k.replace("linear1", "linear1_v")
35-
# w3 = w[2 * hidden_size : 3 * hidden_size]
36-
# new_dict[k3] = w3
37-
# k4 = k.replace("linear1", "linear1_mlp")
38-
# w4 = w[3 * hidden_size :]
39-
# new_dict[k4] = w4
40-
# elif k.startswith("single_blocks") and "linear2" in k:
41-
# k1 = k.replace("linear2", "linear2.fc")
42-
# new_dict[k1] = w
43-
# else:
44-
# new_dict[k] = w
45-
# return new_dict
46-
47-
48-
49-
501
def HunyuanDiTStateDictConverter(state_dict):
512
new_dict = {}
523

diffsynth/utils/state_dict_converters/hunyuan_t5_converter.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,3 @@
1-
# import torch
2-
3-
# def HunyuanT5StateDictConverter(state_dict: dict) -> dict:
4-
# if 'state_dict' in state_dict:
5-
# sd = state_dict["state_dict"]
6-
# else:
7-
# sd = state_dict
8-
9-
# prefix = 'module.text_tower.encoder.'
10-
11-
# first_key = next(iter(sd.keys()), None)
12-
# if first_key is None or not first_key.startswith(prefix):
13-
# return sd
14-
15-
# newsd = {}
16-
# for k, v in sd.items():
17-
# if k.startswith(prefix):
18-
# newsd[k[len(prefix):]] = v
19-
# else:
20-
# newsd[k] = v
21-
22-
# return newsd
23-
24-
25-
26-
27-
# import torch
28-
29-
# def HunyuanT5StateDictConverter(state_dict: dict) -> dict:
30-
# if 'state_dict' in state_dict_dict:
31-
# sd = state_dict_dict["state_dict"]
32-
# elif 'model' in state_dict_dict:
33-
# sd = state_dict_dict["model"]
34-
# else:
35-
# sd = state_dict_dict
36-
37-
# prefix = 'module.text_tower.encoder.'
38-
39-
# first_key = next(iter(sd.keys()), None)
40-
# if first_key is None or not first_key.startswith(prefix):
41-
# return sd
42-
43-
# newsd = {}
44-
# for k, v in sd.items():
45-
# if k.startswith(prefix):
46-
# newsd[k[len(prefix):]] = v
47-
# else:
48-
# newsd[k] = v
49-
50-
# return newsd
51-
521
import torch
532

543
def HunyuanT5StateDictConverter(state_dict: dict) -> dict:

diffsynth/utils/state_dict_converters/hunyuan_vae_converter.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,3 @@
1-
# import torch
2-
3-
# def HunyuanVAEStateDictConverter(state_dict: dict) -> dict:
4-
# if 'state_dict' in state_dict:
5-
# sd = state_dict["state_dict"]
6-
# elif 'model' in state_dict:
7-
# sd = state_dict["model"]
8-
# else:
9-
# sd = state_dict
10-
11-
# converted_state_dict = {}
12-
# for key, value in sd.items():
13-
# if 'weight' in key:
14-
# if len(value.shape) == 5 and value.shape[2] == 1:
15-
# converted_state_dict[key] = value.squeeze(2)
16-
# else:
17-
# converted_state_dict[key] = value
18-
# else:
19-
# converted_state_dict[key] = value
20-
21-
# return converted_state_dict
22-
23-
24-
25-
# import torch
26-
# def HunyuanVAEStateDictConverter(state_dict: dict) -> dict:
27-
# if 'state_dict' in state_dict:
28-
# sd = state_dict["state_dict"]
29-
# elif 'model' in state_dict:
30-
# sd = state_dict["model"]
31-
# else:
32-
# sd = state_dict
33-
34-
# first_key = next(iter(sd.keys()), None)
35-
# prefix = None
36-
# if first_key is not None:
37-
# if first_key.startswith("model."):
38-
# prefix = "model."
39-
# elif first_key.startswith("vae."):
40-
# prefix = "vae."
41-
42-
# converted_state_dict = {}
43-
# for key, value in sd.items():
44-
45-
# if prefix is not None and key.startswith(prefix):
46-
# key = key[len(prefix):]
47-
48-
# if 'weight' in key:
49-
# if len(value.shape) == 5 and value.shape[2] == 1:
50-
# converted_state_dict[key] = value.squeeze(2)
51-
# else:
52-
# converted_state_dict[key] = value
53-
# else:
54-
# converted_state_dict[key] = value
55-
56-
# return converted_state_dict
57-
581
import torch
592

603
def HunyuanVAEStateDictConverter(state_dict: dict) -> dict:

0 commit comments

Comments
 (0)