Skip to content

Commit 74fe45e

Browse files
committed
update chroma transformer approximator init params
1 parent 35dc65b commit 74fe45e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def __init__(
416416
num_attention_heads: int = 24,
417417
joint_attention_dim: int = 4096,
418418
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
419-
approximator_in_factor: int = 16,
419+
approximator_num_channels: int = 64,
420420
approximator_hidden_dim: int = 5120,
421421
approximator_layers: int = 5,
422422
):
@@ -427,11 +427,11 @@ def __init__(
427427
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
428428

429429
self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
430-
num_channels=approximator_in_factor,
430+
num_channels=approximator_num_channels // 4,
431431
out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
432432
)
433433
self.distilled_guidance_layer = ChromaApproximator(
434-
in_dim=64,
434+
in_dim=approximator_num_channels,
435435
out_dim=self.inner_dim,
436436
hidden_dim=approximator_hidden_dim,
437437
n_layers=approximator_layers,

tests/models/transformers/test_models_transformer_chroma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def prepare_init_args_and_inputs_for_common(self):
128128
"num_attention_heads": 2,
129129
"joint_attention_dim": 32,
130130
"axes_dims_rope": [4, 4, 8],
131-
"approximator_in_factor": 32,
131+
"approximator_num_channels": 8,
132132
"approximator_hidden_dim": 16,
133133
"approximator_layers": 1,
134134
}

0 commit comments

Comments (0)