Commit 6417f5d

Merge pull request #580 from kohya-ss/dev
fix clip skip not working in weighted caption training and sample gen
2 parents 363f1df + 8088c04

3 files changed: +17 −12 lines

README.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -140,6 +140,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### 8 Jun. 2023, 2023/06/08
+
+- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.
+- 重みづけキャプションでの学習時(`--weighted_captions`指定時)および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。
+
 ### 6 Jun. 2023, 2023/06/06
 
 - Fix `train_network.py` to probably work with older versions of LyCORIS.
```

library/custom_train_functions.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -265,11 +265,6 @@ def get_unweighted_text_embeddings(
                 text_embedding = enc_out["hidden_states"][-clip_skip]
                 text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
 
-            # cover the head and the tail by the starting and the ending tokens
-            text_input_chunk[:, 0] = text_input[0, 0]
-            text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
-
             if no_boseos_middle:
                 if i == 0:
                     # discard the ending token
@@ -284,7 +279,12 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = text_encoder(text_input)[0]
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = text_encoder(text_input)[0]
+        else:
+            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
     return text_embeddings
```
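The else-branch added above implements clip skip for the non-chunked path: instead of the encoder's final hidden state, it takes the hidden state `clip_skip` layers from the end and re-applies the final layer norm. A minimal sketch of that logic, assuming Hugging Face transformers' `CLIPTextModel` (the repository wraps the same CLIP text encoder); a tiny randomly initialized config is used here so the snippet runs without downloading weights:

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny random-weight encoder, illustrative only (not the repo's model setup).
config = CLIPTextConfig(
    vocab_size=1000,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=77,
)
text_encoder = CLIPTextModel(config)

def encode(text_input: torch.Tensor, clip_skip=None) -> torch.Tensor:
    if clip_skip is None or clip_skip == 1:
        # Default path: last hidden state (final_layer_norm already applied).
        return text_encoder(text_input)[0]
    # Clip skip: take an earlier layer's hidden state, then re-apply the
    # final layer norm, mirroring the fixed else-branch in this commit.
    enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
    emb = enc_out["hidden_states"][-clip_skip]
    return text_encoder.text_model.final_layer_norm(emb)
```

Both branches return embeddings of the same shape, but with `clip_skip=2` the values come from the penultimate layer, which is what the bug silently discarded before this fix.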
library/lpw_stable_diffusion.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
                 text_embedding = enc_out["hidden_states"][-clip_skip]
                 text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
 
-            # cover the head and the tail by the starting and the ending tokens
-            text_input_chunk[:, 0] = text_input[0, 0]
-            text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
-
             if no_boseos_middle:
                 if i == 0:
                     # discard the ending token
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = pipe.text_encoder(text_input)[0]
+        else:
+            enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
     return text_embeddings
```
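The deleted lines belonged to the chunked path of `get_unweighted_text_embeddings`: long prompts are split into CLIP-window-sized chunks, and each chunk gets the BOS/EOS ids copied onto its head and tail ("cover the head and the tail by the starting and the ending tokens"). The bug was that this leftover block then re-encoded the chunk without clip skip, overwriting the clip-skipped embedding computed just above it. A sketch of the chunking idea the surviving code performs; function and variable names here are illustrative, not the repository's exact API:

```python
import torch

def split_into_chunks(text_input: torch.Tensor, chunk_length: int = 77) -> list:
    """Split token ids (batch, total_len) into CLIP-sized chunks.

    Assumes text_input[0, 0] is the BOS id and text_input[0, -1] is the EOS id,
    as in the original function.
    """
    bos, eos = text_input[0, 0], text_input[0, -1]
    content = text_input[:, 1:-1]   # strip the outer BOS/EOS
    step = chunk_length - 2         # room left after re-adding BOS/EOS per chunk
    chunks = []
    for start in range(0, content.shape[1], step):
        chunk = content[:, start : start + step]
        # cover the head and the tail with the starting and ending tokens
        chunk = torch.cat(
            [bos.expand(chunk.shape[0], 1), chunk, eos.expand(chunk.shape[0], 1)],
            dim=1,
        )
        chunks.append(chunk)
    return chunks
```

Each chunk is then encoded separately (with or without clip skip) and the results are concatenated along the sequence axis, which is why a stray extra encoder call inside the loop could clobber the clip-skipped result.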
