Skip to content

Commit 6857b44

Browse files
authored
feat: add text encoder files (mindspore-lab#1481)
- Add 8 text encoder files: bert, flux, hunyuan_video, llama, qwen_image, qwen_vl, sd3_clip, t5 Signed-off-by: vigo999 <zwiori1982@163.com>
1 parent 68f89ed commit 6857b44

File tree

8 files changed

+2160
-0
lines changed

8 files changed

+2160
-0
lines changed
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import comfy.ops
2+
from comfy.ldm.modules.attention import optimized_attention_for_device
3+
from mindspore_patch.utils import dtype_to_max
4+
5+
import mindspore
6+
from mindspore import mint
7+
8+
9+
class BertAttention(mindspore.nn.Cell):
    """Self-attention q/k/v projections for a BERT layer (no output projection)."""

    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        # HuggingFace BERT naming (query/key/value) is kept so checkpoints load as-is.
        self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=None)
        self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=None)
        self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=None)

    def construct(self, x, mask=None, optimized_attention=None):
        """Project x to q/k/v and run the supplied attention kernel."""
        return optimized_attention(self.query(x), self.key(x), self.value(x), self.heads, mask)
25+
26+
27+
class BertOutput(mindspore.nn.Cell):
    """Dense projection, residual add, then LayerNorm (BERT "output" sub-module)."""

    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=None)
        self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=None)
        # The reference implementation applies Dropout(p=0.0) here — a no-op, so it is omitted.

    def construct(self, x, y):
        """Return LayerNorm(dense(x) + y), where y is the residual branch."""
        projected = self.dense(x)
        return self.LayerNorm(projected + y)
39+
40+
41+
class BertAttentionBlock(mindspore.nn.Cell):
    """Self-attention followed by the projection/residual/LayerNorm output stage."""

    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        # HuggingFace BERT naming ("self"/"output") is kept so checkpoints load as-is.
        self.self = BertAttention(embed_dim, heads, dtype, None, operations)
        self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, None, operations)

    def construct(self, x, mask, optimized_attention):
        attn = self.self(x, mask, optimized_attention)
        # x is the residual input to the output stage.
        return self.output(attn, x)
50+
51+
52+
class BertIntermediate(mindspore.nn.Cell):
    """BERT feed-forward expansion: Linear(embed_dim -> intermediate_dim) + GELU."""

    def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=None)

    def construct(self, x):
        """Return gelu(dense(x))."""
        x = self.dense(x)
        # Fix: the functional gelu lives at mint.nn.functional.gelu;
        # mindspore.mint has no top-level `functional` submodule, so the
        # original `mint.functional.gelu` would raise AttributeError at runtime.
        return mint.nn.functional.gelu(x)
60+
61+
62+
class BertBlock(mindspore.nn.Cell):
    """One transformer encoder layer: attention block + feed-forward + output stage."""

    def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, None, operations)
        self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, None, operations)
        self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, None, operations)

    def construct(self, x, mask, optimized_attention):
        hidden = self.attention(x, mask, optimized_attention)
        expanded = self.intermediate(hidden)
        # `hidden` is the residual input to the final output stage.
        return self.output(expanded, hidden)
73+
74+
75+
class BertEncoder(mindspore.nn.Cell):
    """Stack of `num_layers` BertBlocks with an optional intermediate-layer tap."""

    def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        blocks = [
            BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, None, operations)
            for _ in range(num_layers)
        ]
        self.layer = mindspore.nn.CellList(blocks)

    def construct(self, x, mask=None, intermediate_output=None):
        """Run all layers over x.

        Returns (final hidden states, clone of layer `intermediate_output`'s
        output) — negative indices count from the end; the second element is
        None when no tap was requested (or the index never matched).
        """
        optimized_attention = optimized_attention_for_device(None, mask=mask is not None, small_input=True)

        if intermediate_output is not None and intermediate_output < 0:
            intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for idx, block in enumerate(self.layer):
            x = block(x, mask, optimized_attention)
            if idx == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
98+
99+
100+
class BertEmbeddings(mindspore.nn.Cell):
    """Token + position + token-type embeddings, followed by LayerNorm."""

    def __init__(
        self,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        pad_token_id,
        embed_dim,
        layer_norm_eps,
        dtype,
        device,
        operations,
    ):
        super().__init__()
        self.word_embeddings = operations.Embedding(
            vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=None
        )
        self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=None)
        self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=None)

        self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=None)

    def construct(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
        """Embed tokens (or use precomputed `embeds`), add position and
        token-type embeddings, then normalize."""
        x = embeds if embeds is not None else self.word_embeddings(input_tokens, out_dtype=dtype)
        # Positions are implicitly 0..seq_len-1: slice the position table to the sequence length.
        x = x + comfy.ops.cast_to_input(self.position_embeddings.weight[: x.shape[1]], x)
        if token_type_ids is None:
            # No segment ids supplied: treat everything as segment 0.
            x = x + comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
        else:
            x = x + self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
        return self.LayerNorm(x)
134+
135+
136+
class BertModel_(mindspore.nn.Cell):
    """BERT backbone: embeddings + encoder stack (no pooler or task head).

    construct() returns (final hidden states, intermediate hidden states or None).
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        embed_dim = config_dict["hidden_size"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = BertEmbeddings(
            config_dict["vocab_size"],
            config_dict["max_position_embeddings"],
            config_dict["type_vocab_size"],
            config_dict["pad_token_id"],
            embed_dim,
            layer_norm_eps,
            dtype,
            None,
            operations,
        )
        self.encoder = BertEncoder(
            config_dict["num_hidden_layers"],
            embed_dim,
            config_dict["intermediate_size"],
            config_dict["num_attention_heads"],
            layer_norm_eps,
            dtype,
            None,
            operations,
        )

    def construct(
        self,
        input_tokens,
        attention_mask=None,
        embeds=None,
        num_tokens=None,  # unused; kept for caller compatibility
        intermediate_output=None,
        final_layer_norm_intermediate=True,  # unused; kept for caller compatibility
        dtype=None,
        embeds_info=[],  # unused; mutable default kept only to preserve the signature
    ):
        x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
        mask = None
        if attention_mask is not None:
            # Build an additive attention bias: 0 where attended, -max(dtype) where masked.
            mask = 1.0 - attention_mask.to(x.dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ).expand((attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]))
            # Fix: MindSpore's boolean dtype is `mindspore.bool_`; the original
            # `mindspore.bool` is not a defined attribute and would raise at runtime.
            mask = mask.masked_fill(mask.to(mindspore.bool_), -dtype_to_max(x.dtype))

        x, i = self.encoder(x, mask, intermediate_output)
        return x, i
185+
186+
187+
class BertModel(mindspore.nn.Cell):
    """Wrapper matching the HuggingFace layout: the backbone lives under `.bert`."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.bert = BertModel_(config_dict, dtype, None, operations)
        self.num_layers = config_dict["num_hidden_layers"]

    def get_input_embeddings(self):
        """Return the word-embedding table cell."""
        return self.bert.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        """Replace the word-embedding table cell."""
        self.bert.embeddings.word_embeddings = embeddings

    def construct(self, *args, **kwargs):
        # Pure pass-through to the wrapped backbone.
        return self.bert(*args, **kwargs)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import os
2+
3+
import comfy.model_management
4+
import comfy.text_encoders.sd3_clip
5+
import comfy.text_encoders.t5
6+
from comfy import sd1_clip
7+
from transformers import T5TokenizerFast
8+
9+
import mindspore
10+
from mindspore import mint
11+
12+
13+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer: no start token, no fixed-length padding, min length 256."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(
            os.path.join(here, "t5_tokenizer"),
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=4096,  # T5-XXL hidden size
            embedding_key="t5xxl",
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,  # effectively unbounded
            min_length=256,
            tokenizer_data=tokenizer_data,
        )
29+
30+
31+
class FluxTokenizer:
    """Paired tokenizer for Flux: CLIP-L plus T5-XXL."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize `text` with both tokenizers; result keys are "l" and "t5xxl"."""
        return {
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs),
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        # Only the CLIP-L tokenizer is used for the reverse mapping.
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        # Tokenizers carry no learnable state.
        return {}
47+
48+
49+
class FluxClipModel(mindspore.nn.Cell):
    """Dual text encoder for Flux: CLIP-L (pooled output) + T5-XXL (sequence output)."""

    def __init__(self, dtype_t5=None, device=None, dtype=None, model_options={}):
        super().__init__()
        dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype)
        self.clip_l = sd1_clip.SDClipModel(dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(dtype=dtype_t5, model_options=model_options)
        self.dtypes = {dtype, dtype_t5}

    def set_clip_options(self, options):
        for encoder in (self.clip_l, self.t5xxl):
            encoder.set_clip_options(options)

    def reset_clip_options(self):
        for encoder in (self.clip_l, self.t5xxl):
            encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Return (T5 sequence embedding, CLIP-L pooled embedding)."""
        t5_out, _t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs["t5xxl"])
        _l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
        return t5_out, l_pooled

    def load_sd(self, sd):
        """Route a raw checkpoint to the encoder it belongs to, based on its key layout."""
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            prefixed = {f"clip_l.transformer.{k}": v for k, v in sd.items()}
            return self.clip_l.load_sd(prefixed)
        prefixed = {f"t5xxl.transformer.{k}": v for k, v in sd.items()}
        return self.t5xxl.load_sd(prefixed)
80+
81+
82+
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Return a FluxClipModel subclass with the T5 weight dtype baked in."""

    class FluxClipModel_(FluxClipModel):
        def __init__(self, device=None, dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Scaled-fp8 T5 weights are not supported in this port yet.
                raise NotImplementedError
            super().__init__(dtype_t5=dtype_t5, device=None, dtype=dtype, model_options=model_options)

    return FluxClipModel_

0 commit comments

Comments
 (0)