Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ int main(int argc, const char* argv[]) {
}
}

if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
if (params.control_image_path.size() > 0) {
int width = 0;
int height = 0;
control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height);
Expand Down
34 changes: 32 additions & 2 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ namespace Flux {
bool guidance_embed = true;
bool flash_attn = true;
bool is_chroma = false;
SDVersion version = VERSION_FLUX;
};

struct Flux : public GGMLBlock {
Expand Down Expand Up @@ -720,6 +721,7 @@ namespace Flux {
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);

img = img_in->forward(ctx, img);

struct ggml_tensor* vec;
struct ggml_tensor* txt_img_mask = NULL;
if (params.is_chroma) {
Expand Down Expand Up @@ -849,14 +851,36 @@ namespace Flux {
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];

if (c_concat != NULL) {
if (params.version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = process_img(ctx, masked);
mask = process_img(ctx, mask);

img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));

masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);
control = patchify(ctx, control, patch_size);

img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != NULL);

ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
control = patchify(ctx, control, patch_size);
img = ggml_concat(ctx, img, control, 0);
}

if (ref_latents.size() > 0) {
Expand All @@ -867,6 +891,7 @@ namespace Flux {
}

auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]

if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
Expand Down Expand Up @@ -896,13 +921,18 @@ namespace Flux {
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: GGMLRunner(backend, offload_params_to_cpu), use_mask(use_mask) {
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
flux_params.version = version;
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
} else if (version == VERSION_FLUX_CONTROLS) {
flux_params.in_channels = 128;
} else if (version == VERSION_FLEX_2) {
flux_params.in_channels = 196;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
Expand Down
14 changes: 10 additions & 4 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,18 +428,24 @@ __STATIC_INLINE__ void sd_image_to_tensor(sd_image_t image,

__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
struct ggml_tensor* output,
float masked_value = 0.5f) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
float rescale_mx = mask->ne[0] / output->ne[0];
float rescale_my = mask->ne[1] / output->ne[1];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
int mx = (int)(ix * rescale_mx);
int my = (int)(iy * rescale_my);
float m = ggml_tensor_get_f32(mask, mx, my);
m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, ix, iy);
ggml_tensor_set_f32(mask, m, mx, my);
for (int k = 0; k < channels; k++) {
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
value = (1 - m) * (value - masked_value) + masked_value;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
Expand Down
9 changes: 7 additions & 2 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1803,10 +1803,15 @@ SDVersion ModelLoader::get_sd_version() {
}

if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
}
return VERSION_FLUX;
}

Expand Down
12 changes: 9 additions & 3 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ enum SDVersion {
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_FILL,
VERSION_FLUX_CONTROLS,
VERSION_FLEX_2,
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
Expand Down Expand Up @@ -66,7 +68,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
}

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
return true;
}
return false;
Expand All @@ -80,7 +82,7 @@ static inline bool sd_version_is_wan(SDVersion version) {
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
return true;
}
return false;
Expand All @@ -97,8 +99,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
}

static inline bool sd_version_is_control(SDVersion version) {
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
}

static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version);
}

enum PMVersion {
Expand Down
Loading
Loading