194 changes: 138 additions & 56 deletions library/strategy_sd.py
@@ -30,81 +30,171 @@ def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Opt
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2


self.break_separator = "BREAK"

def _split_on_break(self, text: str) -> List[str]:
"""Split text on BREAK separator (case-sensitive), filtering empty segments."""
segments = text.split(self.break_separator)
# Filter out empty or whitespace-only segments
filtered = [seg.strip() for seg in segments if seg.strip()]
# Return at least one segment to maintain consistency
return filtered if filtered else [""]
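# Illustrative behavior (a sketch, not part of the diff): with break_separator = "BREAK",
# "a castle BREAK misty forest" -> ["a castle", "misty forest"], while "BREAK" or ""
# -> [""], so callers always receive at least one segment.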

def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Tokenize multiple segments and concatenate them."""
if len(segments) == 1:
# No BREAK present, use existing logic
if weighted:
return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True)
else:
tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length)
return tokens, None

# Multiple segments - tokenize each separately
all_tokens = []
all_weights = [] if weighted else None

for segment in segments:
if weighted:
seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True)
all_tokens.append(seg_tokens)
all_weights.append(seg_weights)
else:
seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length)
all_tokens.append(seg_tokens)

# Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len])
combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0)
combined_weights = None
if weighted:
combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0)

return combined_tokens, combined_weights
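# Shape sketch (assuming _get_input_ids returns either a 1-D [seq_len] tensor or a
# 2-D [n_chunks, seq_len] tensor, as the dim checks above imply): two 1-D segments of
# length 77 concatenate along dim=0 into [154]; two [n, 77] tensors concatenate along
# dim=1 into [n, 154].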

def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]


tokens_list = []
for t in text:
segments = self._split_on_break(t)
tokens, _ = self._tokenize_segments(segments, weighted=False)
tokens_list.append(tokens)

# Pad tokens to same length for stacking
max_length = max(t.shape[-1] for t in tokens_list)
padded_tokens = []
for tokens in tokens_list:
if tokens.shape[-1] < max_length:
# Pad with pad_token_id
pad_size = max_length - tokens.shape[-1]
if tokens.dim() == 2:
padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype)
tokens = torch.cat([tokens, padding], dim=1)
else:
padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype)
tokens = torch.cat([tokens, padding], dim=0)
padded_tokens.append(tokens)

return [torch.stack(padded_tokens, dim=0)]
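# Worked example (hypothetical prompts and shapes): ["cat", "cat BREAK dog"] yields
# per-prompt token tensors of lengths 77 and 154; the shorter one is right-padded with
# pad_token_id to 154 so torch.stack can return one batched tensor, e.g. (2, 1, 154)
# if each segment tokenizes to a (1, 77) chunk tensor.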

def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text

tokens_list = []
weights_list = []
for t in text:
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
segments = self._split_on_break(t)
tokens, weights = self._tokenize_segments(segments, weighted=True)
tokens_list.append(tokens)
weights_list.append(weights)

return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
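# Note: unlike tokenize() above, this path stacks without padding, so it implicitly
# assumes every prompt in the batch yields the same number of BREAK segments;
# torch.stack would fail on mixed segment counts.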


class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip


def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor:
"""Encode tokens with optional CLIP skip."""
if self.clip_skip is None:
return text_encoder(tokens)[0]

enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
hidden_states = enc_out["hidden_states"][-self.clip_skip]
return text_encoder.text_model.final_layer_norm(hidden_states)
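# Indexing sketch: enc_out["hidden_states"] holds the embedding output plus one entry
# per encoder layer, so clip_skip=2 selects the second-to-last layer's output and
# re-applies final_layer_norm, matching the usual "CLIP skip 2" convention for SD v1
# models.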

def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor,
max_token_length: int, model_max_length: int,
tokenizer: Any) -> torch.Tensor:
"""Reconstruct embeddings from chunked encoding."""
v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>

if not v1:
# v2: merge the concatenated <BOS>...<EOS> <PAD> ... chunks back into a single <BOS>...<EOS> <PAD> ... sequence
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2]
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == tokenizer.eos_token:
chunk[j, 0] = chunk[j, 1]
states_list.append(chunk)
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
else:
# v1: merge the concatenated <BOS>...<EOS> chunks back into a single <BOS>...<EOS> sequence
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))

return torch.cat(states_list, dim=1)
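# Length arithmetic (sketch): n chunks of model_max_length=77 produce n*77 hidden
# states; keeping one shared <BOS>, the 75 middle states of each chunk, and one
# trailing <EOS>/<PAD> leaves n*75 + 2 states, which is what the weight-application
# slices below expect.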

def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor,
weights: torch.Tensor) -> torch.Tensor:
"""Apply weights for single chunk case (no max_token_length)."""
return encoder_hidden_states * weights.squeeze(1).unsqueeze(2)

def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor,
weights: torch.Tensor) -> torch.Tensor:
"""Apply weights for multi-chunk case (with max_token_length)."""
for i in range(weights.shape[1]):
start_idx = i * 75 + 1
end_idx = i * 75 + 76
encoder_hidden_states[:, start_idx:end_idx] = (
encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1)
)
return encoder_hidden_states
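# Example (hypothetical shapes): with weights of shape (b, 2, 77) and hidden states of
# shape (b, 152, 768), chunk 0 rescales positions 1..75 and chunk 1 positions 76..150;
# weights[:, i, 1:-1] drops each chunk's <BOS>/<EOS> weight so only the 75 content
# positions are scaled.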

def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy

# tokens: b,n,77

b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77


tokens = tokens.reshape((-1, model_max_length))
tokens = tokens.to(text_encoder.device)

if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

# bs*3, 77, 768 or 1024

encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens)
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: merge the concatenated <BOS>...<EOS> <PAD> ... chunks back into <BOS>...<EOS> <PAD> ... (honestly not sure this implementation is right)
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # from after <BOS> up to before the last token
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
# empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
states_list.append(chunk) # from after <BOS> up to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: merge the concatenated <BOS>...<EOS> chunks back into <BOS>...<EOS>
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # from after <BOS> up to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)

encoder_hidden_states = self._reconstruct_embeddings(
encoder_hidden_states, tokens, max_token_length,
model_max_length, sd_tokenize_strategy.tokenizer
)

return [encoder_hidden_states]
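# End-to-end shape sketch (hypothetical sizes): tokens_list[0] of shape (b, 3, 77) is
# flattened to (b*3, 77), encoded to (b*3, 77, 768), regrouped to (b, 231, 768), and
# reconstructed above to (b, 3*75 + 2, 768) = (b, 227, 768).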

def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
@@ -113,23 +203,15 @@ def encode_tokens_with_weights(
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]

weights = weights_list[0].to(encoder_hidden_states.device)

# apply weights
if weights.shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)

if weights.shape[1] == 1:
encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for i in range(weights.shape[1]):
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
:, i, 1:-1
].unsqueeze(-1)

encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights)

return [encoder_hidden_states]


class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
140 changes: 140 additions & 0 deletions tests/library/test_strategy_sd_text_encoding.py
@@ -0,0 +1,140 @@
import pytest
import torch
from unittest.mock import Mock

from library.strategy_sd import SdTextEncodingStrategy


class TestSdTextEncodingStrategy:
@pytest.fixture
def strategy(self):
"""Create strategy instance with default settings."""
return SdTextEncodingStrategy(clip_skip=None)

@pytest.fixture
def strategy_with_clip_skip(self):
"""Create strategy instance with CLIP skip enabled."""
return SdTextEncodingStrategy(clip_skip=2)

@pytest.fixture
def mock_tokenizer(self):
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.model_max_length = 77
tokenizer.pad_token_id = 0
tokenizer.eos_token = 2
tokenizer.eos_token_id = 2
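# Note: eos_token is set to the integer id here so the strategy's comparison
# tokens[j, 1] == tokenizer.eos_token in the v2 reconstruction path works on ids.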
return tokenizer

@pytest.fixture
def mock_text_encoder(self):
"""Create a mock text encoder."""
encoder = Mock()
encoder.device = torch.device("cpu")

def encode_side_effect(tokens, output_hidden_states=False, return_dict=False):
batch_size = tokens.shape[0]
seq_len = tokens.shape[1]
hidden_size = 768

# Create deterministic hidden states
hidden_state = torch.ones(batch_size, seq_len, hidden_size) * 0.5

if return_dict:
result = {
"hidden_states": [
hidden_state * 0.8,
hidden_state * 0.9,
hidden_state * 1.0,
]
}
return result
else:
return [hidden_state]

encoder.side_effect = encode_side_effect
encoder.text_model = Mock()
encoder.text_model.final_layer_norm = lambda x: x

return encoder

@pytest.fixture
def mock_tokenize_strategy(self, mock_tokenizer):
"""Create a mock tokenize strategy."""
strategy = Mock()
strategy.tokenizer = mock_tokenizer
return strategy

# Test _encode_with_clip_skip
def test_encode_without_clip_skip(self, strategy, mock_text_encoder):
"""Test encoding without CLIP skip."""
tokens = torch.arange(154).reshape(2, 77)
result = strategy._encode_with_clip_skip(mock_text_encoder, tokens)
assert result.shape == (2, 77, 768)
# Verify deterministic output
assert torch.allclose(result[0, 0, 0], torch.tensor(0.5))

def test_encode_with_clip_skip(self, strategy_with_clip_skip, mock_text_encoder):
"""Test encoding with CLIP skip."""
tokens = torch.arange(154).reshape(2, 77)
result = strategy_with_clip_skip._encode_with_clip_skip(mock_text_encoder, tokens)
assert result.shape == (2, 77, 768)
# With clip_skip=2, should use second-to-last hidden state (0.5 * 0.9 = 0.45)
assert torch.allclose(result[0, 0, 0], torch.tensor(0.45))

# Test _apply_weights_single_chunk
def test_apply_weights_single_chunk(self, strategy):
"""Test applying weights for single chunk case."""
encoder_hidden_states = torch.ones(2, 77, 768)
weights = torch.ones(2, 1, 77) * 0.5
result = strategy._apply_weights_single_chunk(encoder_hidden_states, weights)
assert result.shape == (2, 77, 768)
# Verify weights were applied: 1.0 * 0.5 = 0.5
assert torch.allclose(result[0, 0, 0], torch.tensor(0.5))

# Test _apply_weights_multi_chunk
def test_apply_weights_multi_chunk(self, strategy):
"""Test applying weights for multi-chunk case."""
# Simulating 2 chunks: 2*75+2 = 152 tokens
encoder_hidden_states = torch.ones(2, 152, 768)
weights = torch.ones(2, 2, 77) * 0.5
result = strategy._apply_weights_multi_chunk(encoder_hidden_states, weights)
assert result.shape == (2, 152, 768)
# Check that weights were applied to middle sections
assert torch.allclose(result[0, 1, 0], torch.tensor(0.5))
assert torch.allclose(result[0, 76, 0], torch.tensor(0.5))

# Integration tests
def test_encode_tokens_basic(self, strategy, mock_tokenize_strategy, mock_text_encoder):
"""Test basic token encoding flow."""
tokens = torch.arange(154).reshape(2, 1, 77)
models = [mock_text_encoder]
tokens_list = [tokens]

result = strategy.encode_tokens(mock_tokenize_strategy, models, tokens_list)

assert len(result) == 1
assert result[0].shape[0] == 2 # batch size
assert result[0].shape[2] == 768 # hidden size
# Verify deterministic output
assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.5))

def test_encode_tokens_with_weights_single_chunk(self, strategy, mock_tokenize_strategy, mock_text_encoder):
"""Test weighted encoding with single chunk."""
tokens = torch.arange(154).reshape(2, 1, 77)
weights = torch.ones(2, 1, 77) * 0.5
models = [mock_text_encoder]
tokens_list = [tokens]
weights_list = [weights]

result = strategy.encode_tokens_with_weights(mock_tokenize_strategy, models, tokens_list, weights_list)

assert len(result) == 1
assert result[0].shape[0] == 2
assert result[0].shape[2] == 768
# Verify weights were applied: 0.5 (encoder output) * 0.5 (weight) = 0.25
assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.25))


if __name__ == "__main__":
pytest.main([__file__, "-v"])