From d4c4e38e82ca565fa5ce1c404f1cd726ed42121b Mon Sep 17 00:00:00 2001
From: yiming-l21
Date: Fri, 7 Nov 2025 10:06:55 +0800
Subject: [PATCH] feat: implement FLUX.1-Control model.

---
 xllm/core/framework/batch/dit_batch.cpp       |   7 +-
 .../framework/request/dit_request_params.cpp  |  10 +
 .../framework/request/dit_request_state.h     |   2 +
 xllm/core/runtime/dit_forward_params.h        |   2 +
 xllm/models/dit/autoencoder_kl.h              |  33 +-
 xllm/models/dit/dit.h                         |   2 +-
 xllm/models/dit/pipeline_flux_control.h       | 369 ++++++++++++++++++
 xllm/models/models.h                          |  35 +-
 xllm/proto/image_generation.proto             |   3 +
 9 files changed, 439 insertions(+), 24 deletions(-)
 create mode 100644 xllm/models/dit/pipeline_flux_control.h
 mode change 100755 => 100644 xllm/models/models.h

diff --git a/xllm/core/framework/batch/dit_batch.cpp b/xllm/core/framework/batch/dit_batch.cpp
index 1c09327a..6b192b18 100644
--- a/xllm/core/framework/batch/dit_batch.cpp
+++ b/xllm/core/framework/batch/dit_batch.cpp
@@ -62,7 +62,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
 
   std::vector<torch::Tensor> images;
   std::vector<torch::Tensor> mask_images;
-
+  std::vector<torch::Tensor> control_images;
   std::vector<torch::Tensor> latents;
   std::vector<torch::Tensor> masked_image_latents;
   for (const auto& request : request_vec_) {
@@ -96,6 +96,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
 
     images.emplace_back(input_params.image);
     mask_images.emplace_back(input_params.mask_image);
+    control_images.emplace_back(input_params.control_image);
   }
 
   if (input.prompts.size() != request_vec_.size()) {
@@ -122,6 +123,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
     input.mask_images = torch::stack(mask_images);
   }
 
+  if (check_tensors_valid(control_images)) {
+    input.control_image = torch::stack(control_images);
+  }
+
   if (check_tensors_valid(prompt_embeds)) {
     input.prompt_embeds = torch::stack(prompt_embeds);
   }
diff --git a/xllm/core/framework/request/dit_request_params.cpp b/xllm/core/framework/request/dit_request_params.cpp
index 2d01537f..00b1b5fb 100644
--- a/xllm/core/framework/request/dit_request_params.cpp
+++ b/xllm/core/framework/request/dit_request_params.cpp
@@ -270,6 +270,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
     }
   }
 
+  if (input.has_control_image()) {
+    std::string raw_bytes;
+    if (!butil::Base64Decode(input.control_image(), &raw_bytes)) {
+      LOG(ERROR) << "Base64 control_image decode failed.";
+    }
+    if (!decoder.decode(raw_bytes, input_params.control_image)) {
+      LOG(ERROR) << "control_image decode failed.";
+    }
+  }
+
   // generation params
   const auto& params = request.parameters();
   if (params.has_size()) {
diff --git a/xllm/core/framework/request/dit_request_state.h b/xllm/core/framework/request/dit_request_state.h
index fab43cb1..7e69bcb6 100644
--- a/xllm/core/framework/request/dit_request_state.h
+++ b/xllm/core/framework/request/dit_request_state.h
@@ -92,6 +92,8 @@ struct DiTInputParams {
 
   torch::Tensor image;
 
+  torch::Tensor control_image;
+
   torch::Tensor mask_image;
 
   torch::Tensor masked_image_latent;
diff --git a/xllm/core/runtime/dit_forward_params.h b/xllm/core/runtime/dit_forward_params.h
index e52d4bb2..a96a5118 100644
--- a/xllm/core/runtime/dit_forward_params.h
+++ b/xllm/core/runtime/dit_forward_params.h
@@ -84,6 +84,8 @@ struct DiTForwardInput {
 
   torch::Tensor mask_images;
 
+  torch::Tensor control_image;
+
   torch::Tensor masked_image_latents;
 
   torch::Tensor prompt_embeds;
diff --git a/xllm/models/dit/autoencoder_kl.h b/xllm/models/dit/autoencoder_kl.h
index f57acc5b..da5c3023 100644
--- a/xllm/models/dit/autoencoder_kl.h
+++ b/xllm/models/dit/autoencoder_kl.h
@@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
                         bool do_normalize = true,
                         bool do_binarize = false,
                         bool do_convert_rgb = false,
-                        bool do_convert_grayscale = false) {
+                        bool do_convert_grayscale = false,
+                        int64_t latent_channels = 4) {
     const auto& model_args = context.get_model_args();
+    dtype_ = context.get_tensor_options().dtype().toScalarType();
    scale_factor_ = 1 << model_args.block_out_channels().size();
-    latent_channels_ = 4;
+    latent_channels_ = latent_channels;
     do_resize_ = do_resize;
     do_normalize_ = do_normalize;
     do_binarize_ = do_binarize;
@@ -86,8 +88,29 @@ class VAEImageProcessorImpl : public torch::nn::Module {
       std::optional<int64_t> width = std::nullopt,
       const std::string& resize_mode = "default",
       std::optional<std::tuple<int64_t, int64_t, int64_t, int64_t>>
-          crop_coords = std::nullopt) {
+          crop_coords = std::nullopt,
+      const bool is_pil_image = false) {
     torch::Tensor processed = image.clone();
+    if (is_pil_image) {
+      auto dims = processed.dim();
+      if (dims < 2 || dims > 4) {
+        LOG(FATAL) << "Unsupported PIL image dimension: " << dims;
+      }
+      if (dims == 4) {
+        if (processed.size(1) == 3 || processed.size(1) == 1) {
+          processed = processed.permute({0, 2, 3, 1});
+        }
+        processed = processed.squeeze(0);
+        dims = processed.dim();
+      }
+      processed = processed.to(torch::kFloat);
+      processed = processed / 255.0f;
+      if (dims == 2) {
+        processed = processed.unsqueeze(0).unsqueeze(0);
+      } else {
+        processed = processed.permute({2, 0, 1}).unsqueeze(0);
+      }
+    }
     if (processed.dtype() != torch::kFloat32) {
       processed = processed.to(torch::kFloat32);
     }
@@ -116,7 +139,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (channel == latent_channels_) {
       return image;
     }
-
     auto [target_h, target_w] =
         get_default_height_width(processed, height, width);
     if (do_resize_) {
@@ -129,7 +151,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (do_binarize_) {
       processed = (processed >= 0.5f).to(torch::kFloat32);
     }
-    processed = processed.to(image.dtype());
+    processed = processed.to(dtype_);
     return processed;
   }
 
@@ -202,6 +224,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
   bool do_binarize_ = false;
   bool do_convert_rgb_ = false;
   bool do_convert_grayscale_ = false;
+  torch::ScalarType dtype_ = torch::kFloat32;
 };
 
 TORCH_MODULE(VAEImageProcessor);
diff --git a/xllm/models/dit/dit.h b/xllm/models/dit/dit.h
index e9d9302a..d333758d 100644
--- a/xllm/models/dit/dit.h
+++ b/xllm/models/dit/dit.h
@@ -1436,7 +1436,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module {
     proj_out_->verify_loaded_weights(prefix + "proj_out.");
   }
 
-  int64_t in_channels() { return out_channels_; }
+  int64_t in_channels() { return in_channels_; }
   bool guidance_embeds() { return guidance_embeds_; }
 
  private:
diff --git a/xllm/models/dit/pipeline_flux_control.h b/xllm/models/dit/pipeline_flux_control.h
new file mode 100644
index 00000000..4e0cd8e8
--- /dev/null
+++ b/xllm/models/dit/pipeline_flux_control.h
@@ -0,0 +1,369 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include "core/layers/pos_embedding.h"
+#include "core/layers/rotary_embedding.h"
+#include "dit.h"
+#include "pipeline_flux_base.h"
+// FluxControl pipeline, compatible with HuggingFace Diffusers weights.
+// Reference:
+// https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_control.py
+
+namespace xllm {
+
+class FluxControlPipelineImpl : public FluxPipelineBaseImpl {
+ public:
+  FluxControlPipelineImpl(const DiTModelContext& context) {
+    auto model_args = context.get_model_args("vae");
+    options_ = context.get_tensor_options();
+    device_ = options_.device();
+    dtype_ = options_.dtype().toScalarType();
+    vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
+    vae_shift_factor_ = model_args.shift_factor();
+    vae_scaling_factor_ = model_args.scale_factor();
+    latent_channels_ = model_args.latent_channels();
+
+    default_sample_size_ = 128;
+    tokenizer_max_length_ = 77;  // TODO: get from config file
+    LOG(INFO) << "Initializing FluxControl pipeline...";
+    image_processor_ = VAEImageProcessor(
+        context.get_model_context("vae"), true, true, false, false, false);
+    vae_ = VAE(context.get_model_context("vae"));
+    LOG(INFO) << "VAE initialized.";
+    pos_embed_ = register_module(
+        "pos_embed",
+        FluxPosEmbed(10000,
+                     context.get_model_args("transformer").axes_dims_rope()));
+    transformer_ = FluxDiTModel(context.get_model_context("transformer"));
+    LOG(INFO) << "DiT transformer initialized.";
+    t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
+    LOG(INFO) << "T5 initialized.";
+    clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
+    LOG(INFO) << "CLIP text model initialized.";
+    scheduler_ =
+        FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
+    LOG(INFO) << "FluxControl pipeline initialized.";
+    register_module("vae", vae_);
+    LOG(INFO) << "VAE registered.";
+    register_module("vae_image_processor", image_processor_);
+    LOG(INFO) << "VAE image processor registered.";
+    register_module("transformer", transformer_);
+    LOG(INFO) << "DiT transformer registered.";
+    register_module("t5", t5_);
+    LOG(INFO) << "T5 registered.";
+    register_module("scheduler", scheduler_);
+    LOG(INFO) << "Scheduler registered.";
+    register_module("clip_text_model", clip_text_model_);
+    LOG(INFO) << "CLIP text model registered.";
+  }
+
+  DiTForwardOutput forward(const DiTForwardInput& input) {
+    const auto& generation_params = input.generation_params;
+    int64_t height = generation_params.height;
+    int64_t width = generation_params.width;
+    auto seed = generation_params.seed > 0 ? generation_params.seed : 42;
+    auto prompts = std::make_optional(input.prompts);
+    auto prompts_2 = input.prompts_2.empty()
+                         ? std::nullopt
+                         : std::make_optional(input.prompts_2);
+
+    auto control_image = input.control_image;
+
+    auto latents = input.latents.defined() ? std::make_optional(input.latents)
+                                           : std::nullopt;
+    auto prompt_embeds = input.prompt_embeds.defined()
+                             ? std::make_optional(input.prompt_embeds)
+                             : std::nullopt;
+    auto pooled_prompt_embeds =
+        input.pooled_prompt_embeds.defined()
+            ? std::make_optional(input.pooled_prompt_embeds)
+            : std::nullopt;
+
+    std::vector<torch::Tensor> output =
+        forward_(prompts,
+                 prompts_2,
+                 control_image,
+                 height,
+                 width,
+                 generation_params.strength,
+                 generation_params.num_inference_steps,
+                 generation_params.guidance_scale,
+                 generation_params.num_images_per_prompt,
+                 seed,
+                 latents,
+                 prompt_embeds,
+                 pooled_prompt_embeds,
+                 generation_params.max_sequence_length);
+
+    DiTForwardOutput out;
+    out.tensors = torch::chunk(output[0], output[0].size(0), 0);
+    LOG(INFO) << "Output tensor chunks size: " << out.tensors.size();
+    return out;
+  }
+
+  void load_model(std::unique_ptr<DiTModelLoader> loader) {
+    LOG(INFO) << "FluxControlPipeline loading model from "
+              << loader->model_root_path();
+    std::string model_path = loader->model_root_path();
+    auto transformer_loader = loader->take_component_loader("transformer");
+    auto vae_loader = loader->take_component_loader("vae");
+    auto t5_loader = loader->take_component_loader("text_encoder_2");
+    auto clip_loader = loader->take_component_loader("text_encoder");
+    auto tokenizer_loader = loader->take_component_loader("tokenizer");
+    auto tokenizer_2_loader = loader->take_component_loader("tokenizer_2");
+    LOG(INFO)
+        << "FluxControl model components loaded, starting to load weights "
+           "into sub-models.";
+    transformer_->load_model(std::move(transformer_loader));
+    transformer_->to(device_);
+    vae_->load_model(std::move(vae_loader));
+    vae_->to(device_);
+    t5_->load_model(std::move(t5_loader));
+    t5_->to(device_);
+    clip_text_model_->load_model(std::move(clip_loader));
+    clip_text_model_->to(device_);
+    tokenizer_ = tokenizer_loader->tokenizer();
+    tokenizer_2_ = tokenizer_2_loader->tokenizer();
+  }
+
+ private:
+  torch::Tensor encode_vae_image(const torch::Tensor& image, int64_t seed) {
+    torch::Tensor latents = vae_->encode(image, seed);
+    latents = (latents - vae_shift_factor_) * vae_scaling_factor_;
+    return latents;
+  }
+
+  std::pair<torch::Tensor, int64_t> get_timesteps(int64_t num_inference_steps,
+                                                  float strength) {
+    int64_t init_timestep =
+        std::min(static_cast<int64_t>(num_inference_steps * strength),
+                 num_inference_steps);
+
+    int64_t t_start = std::max(num_inference_steps - init_timestep, int64_t(0));
+    int64_t start_idx = t_start * scheduler_->order();
+    auto timesteps =
+        scheduler_->timesteps().slice(0, start_idx).to(device_).to(dtype_);
+    scheduler_->set_begin_index(start_idx);
+    return {timesteps, num_inference_steps - t_start};
+  }
+
+  std::pair<torch::Tensor, torch::Tensor> prepare_latents(
+      int64_t batch_size,
+      int64_t num_channels_latents,
+      int64_t height,
+      int64_t width,
+      int64_t seed,
+      std::optional<torch::Tensor> latents = std::nullopt) {
+    int64_t adjusted_height = 2 * (height / (vae_scale_factor_ * 2));
+    int64_t adjusted_width = 2 * (width / (vae_scale_factor_ * 2));
+    std::vector<int64_t> shape = {
+        batch_size, num_channels_latents, adjusted_height, adjusted_width};
+    if (latents.has_value()) {
+      torch::Tensor latent_image_ids = prepare_latent_image_ids(
+          batch_size, adjusted_height / 2, adjusted_width / 2);
+      return {latents.value(), latent_image_ids};
+    }
+    torch::Tensor latents_tensor = randn_tensor(shape, seed, options_);
+    torch::Tensor packed_latents = pack_latents(latents_tensor,
+                                                batch_size,
+                                                num_channels_latents,
+                                                adjusted_height,
+                                                adjusted_width);
+    torch::Tensor latent_image_ids = prepare_latent_image_ids(
+        batch_size, adjusted_height / 2, adjusted_width / 2);
+    return {packed_latents, latent_image_ids};
+  }
+
+  torch::Tensor prepare_image(torch::Tensor image,
+                              int64_t width,
+                              int64_t height,
+                              int64_t batch_size,
+                              int64_t num_images_per_prompt) {
+    int64_t image_batch_size = image.size(0);
+    int64_t repeat_times;
+    image = image_processor_->preprocess(
+        image, height, width, "default", std::nullopt, true);
+    if (image_batch_size == 1) {
+      repeat_times = batch_size;
+    } else {
+      repeat_times = num_images_per_prompt;
+    }
+    const auto B = image.size(0);
+    const auto C = image.size(1);
+    const auto H = image.size(2);
+    const auto W = image.size(3);
+    image = image.unsqueeze(1)
+                .repeat({1, repeat_times, 1, 1, 1})
+                .reshape({B * repeat_times, C, H, W})
+                .to(device_)
+                .to(dtype_);
+    return image;
+  }
+
+  std::vector<torch::Tensor> forward_(
+      std::optional<std::vector<std::string>> prompt = std::nullopt,
+      std::optional<std::vector<std::string>> prompt_2 = std::nullopt,
+      torch::Tensor control_image = torch::Tensor(),
+      int64_t height = 512,
+      int64_t width = 512,
+      float strength = 1.0f,
+      int64_t num_inference_steps = 50,
+      float guidance_scale = 30.0f,
+      int64_t num_images_per_prompt = 1,
+      int64_t seed = 42,
+      std::optional<torch::Tensor> latents = std::nullopt,
+      std::optional<torch::Tensor> prompt_embeds = std::nullopt,
+      std::optional<torch::Tensor> pooled_prompt_embeds = std::nullopt,
+      int64_t max_sequence_length = 512) {
+    torch::NoGradGuard no_grad;
+    int64_t actual_height = height;
+    int64_t actual_width = width;
+    int64_t batch_size;
+    if (prompt.has_value()) {
+      batch_size = prompt.value().size();
+    } else {
+      batch_size = prompt_embeds.value().size(0);
+    }
+    int64_t total_batch_size = batch_size * num_images_per_prompt;
+    // encode prompts
+    auto [encoded_prompt_embeds, encoded_pooled_embeds, text_ids] =
+        encode_prompt(prompt,
+                      prompt_2,
+                      prompt_embeds,
+                      pooled_prompt_embeds,
+                      num_images_per_prompt,
+                      max_sequence_length);
+
+    // prepare latents
+    int64_t num_channels_latents = transformer_->in_channels() / 8;
+    // encode the control image into packed latents
+    control_image = prepare_image(control_image,
+                                  width,
+                                  height,
+                                  batch_size * num_images_per_prompt,
+                                  num_images_per_prompt);
+    if (control_image.dim() == 4) {
+      auto enc = vae_->encode(control_image, seed);
+      control_image = (enc - vae_shift_factor_) * vae_scaling_factor_;
+      control_image = control_image.to(device_).to(dtype_);
+      auto shape = control_image.sizes();
+      auto height_control_image = shape[2];
+      auto width_control_image = shape[3];
+      control_image = pack_latents(control_image,
+                                   total_batch_size,
+                                   num_channels_latents,
+                                   height_control_image,
+                                   width_control_image);
+    }
+    auto [prepared_latents, latent_image_ids] =
+        prepare_latents(total_batch_size,
+                        num_channels_latents,
+                        actual_height,
+                        actual_width,
+                        seed,
+                        latents);
+    // prepare timesteps
+    std::vector<float> new_sigmas;
+    for (int64_t i = 0; i < num_inference_steps; ++i) {
+      new_sigmas.push_back(1.0f - static_cast<float>(i) /
+                                      (num_inference_steps - 1) *
+                                      (1.0f - 1.0f / num_inference_steps));
+    }
+
+    int64_t image_seq_len = prepared_latents.size(1);
+    float mu = calculate_shift(image_seq_len,
+                               scheduler_->base_image_seq_len(),
+                               scheduler_->max_image_seq_len(),
+                               scheduler_->base_shift(),
+                               scheduler_->max_shift());
+    auto [timesteps, num_inference_steps_actual] = retrieve_timesteps(
+        scheduler_, num_inference_steps, device_, new_sigmas, mu);
+    int64_t num_warmup_steps =
+        std::max(static_cast<int64_t>(timesteps.numel()) -
+                     num_inference_steps_actual * scheduler_->order(),
+                 static_cast<int64_t>(0));
+    // prepare guidance
+    torch::Tensor guidance;
+    if (transformer_->guidance_embeds()) {
+      torch::TensorOptions options =
+          torch::dtype(torch::kFloat32).device(device_);
+
+      guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options);
+      guidance = guidance.expand({prepared_latents.size(0)});
+    }
+    scheduler_->set_begin_index(0);
+    torch::Tensor timestep =
+        torch::empty({prepared_latents.size(0)}, prepared_latents.options());
+    // compute image rotary positional embeddings out of place
+    auto [rot_emb1, rot_emb2] =
+        pos_embed_->forward_cache(text_ids,
+                                  latent_image_ids,
+                                  height / (vae_scale_factor_ * 2),
+                                  width / (vae_scale_factor_ * 2));
+    torch::Tensor image_rotary_emb = torch::stack({rot_emb1, rot_emb2}, 0);
+    for (int64_t i = 0; i < timesteps.numel(); ++i) {
+      torch::Tensor t = timesteps[i].unsqueeze(0);
+      timestep.fill_(t.item());
+      timestep = timestep.to(prepared_latents.dtype());
+      timestep.div_(1000.0f);
+      int64_t step_id = i + 1;
+      auto controlled_latents =
+          torch::cat({prepared_latents, control_image}, 2);
+      torch::Tensor noise_pred = transformer_->forward(controlled_latents,
+                                                       encoded_prompt_embeds,
+                                                       encoded_pooled_embeds,
+                                                       timestep,
+                                                       image_rotary_emb,
+                                                       guidance,
+                                                       step_id);
+      auto prev_latents = scheduler_->step(noise_pred, t, prepared_latents);
+      prepared_latents = prev_latents.detach();
+      noise_pred.reset();
+      prev_latents = torch::Tensor();
+
+      if (latents.has_value() &&
+          prepared_latents.dtype() != latents.value().dtype()) {
+        prepared_latents = prepared_latents.to(latents.value().dtype());
+      }
+    }
+    torch::Tensor image;
+    // unpack latents and decode with the VAE
+    torch::Tensor unpacked_latents = unpack_latents(
+        prepared_latents, actual_height, actual_width, vae_scale_factor_);
+    unpacked_latents =
+        (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
+    unpacked_latents = unpacked_latents.to(dtype_);
+    image = vae_->decode(unpacked_latents);
+    image = image_processor_->postprocess(image, "pil");
+    return std::vector<torch::Tensor>{image};
+  }
+
+ private:
+  FlowMatchEulerDiscreteScheduler scheduler_{nullptr};
+  VAE vae_{nullptr};
+  VAEImageProcessor image_processor_{nullptr};
+  FluxDiTModel transformer_{nullptr};
+  float vae_scaling_factor_;
+  float vae_shift_factor_;
+  int64_t vae_latent_channels_;
+  int default_sample_size_;
+  int64_t latent_channels_;
+  FluxPosEmbed pos_embed_{nullptr};
+};
+TORCH_MODULE(FluxControlPipeline);
+
+REGISTER_DIT_MODEL(flux_control, FluxControlPipeline);
+}  // namespace xllm
diff --git a/xllm/models/models.h b/xllm/models/models.h
old mode 100755
new mode 100644
index 12e8e2d5..161a4c31
--- a/xllm/models/models.h
+++ b/xllm/models/models.h
@@ -16,23 +16,24 @@ limitations under the License.
 #pragma once
 
 #if defined(USE_NPU)
-#include "dit/pipeline_flux.h"       // IWYU pragma: keep
-#include "dit/pipeline_flux_fill.h"  // IWYU pragma: keep
-#include "llm/deepseek_v2.h"         // IWYU pragma: keep
-#include "llm/deepseek_v2_mtp.h"     // IWYU pragma: keep
-#include "llm/deepseek_v3.h"         // IWYU pragma: keep
-#include "llm/glm4_moe.h"            // IWYU pragma: keep
-#include "llm/glm4_moe_mtp.h"        // IWYU pragma: keep
-#include "llm/kimi_k2.h"             // IWYU pragma: keep
-#include "llm/llama.h"               // IWYU pragma: keep
-#include "llm/llama3.h"              // IWYU pragma: keep
-#include "llm/llm_model_base.h"      // IWYU pragma: keep
-#include "llm/qwen2.h"               // IWYU pragma: keep
-#include "llm/qwen3_embedding.h"     // IWYU pragma: keep
-#include "vlm/minicpmv.h"            // IWYU pragma: keep
-#include "vlm/qwen2_5_vl.h"          // IWYU pragma: keep
-#include "vlm/qwen3_vl.h"            // IWYU pragma: keep
-#include "vlm/qwen3_vl_moe.h"        // IWYU pragma: keep
+#include "dit/pipeline_flux.h"          // IWYU pragma: keep
+#include "dit/pipeline_flux_control.h"  // IWYU pragma: keep
+#include "dit/pipeline_flux_fill.h"     // IWYU pragma: keep
+#include "llm/deepseek_v2.h"            // IWYU pragma: keep
+#include "llm/deepseek_v2_mtp.h"        // IWYU pragma: keep
+#include "llm/deepseek_v3.h"            // IWYU pragma: keep
+#include "llm/glm4_moe.h"               // IWYU pragma: keep
+#include "llm/glm4_moe_mtp.h"           // IWYU pragma: keep
+#include "llm/kimi_k2.h"                // IWYU pragma: keep
+#include "llm/llama.h"                  // IWYU pragma: keep
+#include "llm/llama3.h"                 // IWYU pragma: keep
+#include "llm/llm_model_base.h"         // IWYU pragma: keep
+#include "llm/qwen2.h"                  // IWYU pragma: keep
+#include "llm/qwen3_embedding.h"        // IWYU pragma: keep
+#include "vlm/minicpmv.h"               // IWYU pragma: keep
+#include "vlm/qwen2_5_vl.h"             // IWYU pragma: keep
+#include "vlm/qwen3_vl.h"               // IWYU pragma: keep
+#include "vlm/qwen3_vl_moe.h"           // IWYU pragma: keep
 #endif
 
 #include "llm/llm_model_base.h"  // IWYU pragma: keep
diff --git a/xllm/proto/image_generation.proto b/xllm/proto/image_generation.proto
index f16b9f17..cf8eb69b 100644
--- a/xllm/proto/image_generation.proto
+++ b/xllm/proto/image_generation.proto
@@ -43,6 +43,9 @@ message Input {
 
   // An image batch of mask images generated by the VAE
   optional Tensor masked_image_latent = 12;
+
+  // Base64-encoded control image used as the conditioning input
+  optional string control_image = 13;
 }
 
 // Generation parameters container