
Commit 35dc65b (parent: f35ec17)

update chroma transformer params

File tree: 2 files changed (+15, -10 lines)


src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 12 additions & 10 deletions
@@ -164,17 +164,17 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         return x


-class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
-    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
+class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, num_channels: int, out_dim: int):
         super().__init__()

-        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)

         self.register_buffer(
             "mod_proj",
             get_timestep_embedding(
-                torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0
+                torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
             ),
             persistent=False,
         )
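
For reference, a minimal sketch of the tensor shapes implied by the new signature, assuming the public Timesteps and get_timestep_embedding helpers in diffusers.models.embeddings keep their current behavior; the num_channels and out_dim values below are illustrative, not taken from any real Chroma checkpoint.

import torch
from diffusers.models.embeddings import Timesteps, get_timestep_embedding

num_channels, out_dim = 16, 344  # illustrative values, not part of this diff

# Same construction as the new __init__: sinusoidal projections for timestep/guidance
# plus a precomputed modulation-index embedding table.
time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
mod_proj = get_timestep_embedding(
    torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
)

print(time_proj(torch.tensor([999.0])).shape)  # torch.Size([1, 16]) -> one row per timestep
print(mod_proj.shape)                          # torch.Size([344, 32]) -> one row per modulation index
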
@@ -426,14 +426,16 @@ def __init__(

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

-        self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
-            factor=approximator_in_factor,
-            hidden_dim=approximator_hidden_dim,
+        self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
+            num_channels=approximator_in_factor,
             out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
-            embedding_dim=self.inner_dim,
+        )
+        self.distilled_guidance_layer = ChromaApproximator(
+            in_dim=64,
+            out_dim=self.inner_dim,
+            hidden_dim=approximator_hidden_dim,
             n_layers=approximator_layers,
         )
-        self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)

         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
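
As a worked example of the out_dim expression used above, the layer counts below are Flux-sized assumptions for illustration only; this commit does not state Chroma's actual configuration.

# Hypothetical layer counts; Chroma's real config is not part of this commit.
num_layers, num_single_layers = 19, 38
out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2
print(out_dim)  # 114 + 228 + 2 = 344 modulation entries requested from the embedder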

tests/models/transformers/test_models_transformer_chroma.py

Lines changed: 3 additions & 0 deletions
@@ -128,6 +128,9 @@ def prepare_init_args_and_inputs_for_common(self):
             "num_attention_heads": 2,
             "joint_attention_dim": 32,
             "axes_dims_rope": [4, 4, 8],
+            "approximator_in_factor": 32,
+            "approximator_hidden_dim": 16,
+            "approximator_layers": 1,
         }

         inputs_dict = self.dummy_input
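
A small sketch tying the new test keys back to the constructor wiring above; only the keyword names, the tiny values, and the ChromaApproximator call shape come from this diff, while the import path and the inner_dim placeholder are assumptions.

# Assumed import path (the class is defined in the file edited above); not verified here.
from diffusers.models.transformers.transformer_chroma import ChromaApproximator

approximator_hidden_dim = 16  # "approximator_hidden_dim" from the test config
approximator_layers = 1       # "approximator_layers" from the test config
inner_dim = 64                # placeholder; attention_head_dim is not shown in this diff

approximator = ChromaApproximator(
    in_dim=64,                        # matches the hard-coded in_dim above
    out_dim=inner_dim,
    hidden_dim=approximator_hidden_dim,
    n_layers=approximator_layers,
)
print(sum(p.numel() for p in approximator.parameters()))  # tiny parameter count for the test config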
