Skip to content

Commit 93ed721

Browse files
committed
is_chroma
1 parent ad39011 commit 93ed721

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

flux.hpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ namespace Flux {
603603
bool qkv_bias = true;
604604
bool guidance_embed = true;
605605
bool flash_attn = true;
606-
bool chroma_guidance = false;
606+
bool is_chroma = false;
607607
};
608608

609609
struct Flux : public GGMLBlock {
@@ -746,7 +746,7 @@ namespace Flux {
746746
int64_t pe_dim = params.hidden_size / params.num_heads;
747747

748748
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
749-
if (params.chroma_guidance) {
749+
if (params.is_chroma) {
750750
blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
751751
} else {
752752
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -764,7 +764,7 @@ namespace Flux {
764764
i,
765765
params.qkv_bias,
766766
params.flash_attn,
767-
params.chroma_guidance));
767+
params.is_chroma));
768768
}
769769

770770
for (int i = 0; i < params.depth_single_blocks; i++) {
@@ -774,11 +774,11 @@ namespace Flux {
774774
i,
775775
0.f,
776776
params.flash_attn,
777-
params.chroma_guidance));
777+
params.is_chroma));
778778
}
779779

780780
// TODO: no modulation for chroma
781-
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.chroma_guidance));
781+
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma));
782782
}
783783

784784
struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -842,7 +842,7 @@ namespace Flux {
842842

843843
img = img_in->forward(ctx, img);
844844
struct ggml_tensor* vec;
845-
if (params.chroma_guidance) {
845+
if (params.is_chroma) {
846846
int64_t mod_index_length = 344;
847847
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
848848
auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
@@ -915,7 +915,6 @@ namespace Flux {
915915
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
916916

917917
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
918-
919918
return img;
920919
}
921920

@@ -1004,7 +1003,7 @@ namespace Flux {
10041003
}
10051004
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
10061005
// Chroma
1007-
flux_params.chroma_guidance = true;
1006+
flux_params.is_chroma = true;
10081007
}
10091008
size_t db = tensor_name.find("double_blocks.");
10101009
if (db != std::string::npos) {
@@ -1025,7 +1024,7 @@ namespace Flux {
10251024
}
10261025

10271026
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
1028-
if (flux_params.chroma_guidance) {
1027+
if (flux_params.is_chroma) {
10291028
LOG_INFO("Using pruned modulation (Chroma)");
10301029
} else if (!flux_params.guidance_embed) {
10311030
LOG_INFO("Flux guidance is disabled (Schnell mode)");
@@ -1058,10 +1057,10 @@ namespace Flux {
10581057
if (c_concat != NULL) {
10591058
c_concat = to_backend(c_concat);
10601059
}
1061-
if (!flux_params.chroma_guidance) {
1060+
if (!flux_params.is_chroma) {
10621061
y = to_backend(y);
10631062
} else {
1064-
// ggml_arrange is not working on some backends, so let's reuse y to precompute it
1063+
// ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
10651064
std::vector<float> range = arange(0, 344);
10661065
y = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());
10671066
set_backend_tensor_data(y, range.data());

0 commit comments

Comments
 (0)