
Commit df9a55d

feat: implement FLUX.1-Control model.
1 parent: 3028d49

9 files changed: +449 -23 lines changed

xllm/core/framework/batch/dit_batch.cpp

Lines changed: 6 additions & 1 deletion
@@ -62,7 +62,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

   std::vector<torch::Tensor> images;
   std::vector<torch::Tensor> mask_images;
-
+  std::vector<torch::Tensor> control_images;
   std::vector<torch::Tensor> latents;
   std::vector<torch::Tensor> masked_image_latents;
   for (const auto& request : request_vec_) {
@@ -96,6 +96,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

     images.emplace_back(input_params.image);
     mask_images.emplace_back(input_params.mask_image);
+    control_images.emplace_back(input_params.control_image);
   }

   if (input.prompts.size() != request_vec_.size()) {
@@ -122,6 +123,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
     input.mask_images = torch::stack(mask_images);
   }

+  if (check_tensors_valid(control_images)) {
+    input.control_image = torch::stack(control_images);
+  }
+
   if (check_tensors_valid(prompt_embeds)) {
     input.prompt_embeds = torch::stack(prompt_embeds);
   }
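The batching logic mirrors the existing mask_images path: a control image is collected from each request, and the list is stacked into one batched tensor only when every entry is usable. The diff does not show check_tensors_valid itself; a minimal sketch of what such a guard plausibly looks like (an assumption for illustration, not the repository's actual implementation):

#include <torch/torch.h>
#include <vector>

// Hypothetical guard: stacking is only safe when the list is non-empty
// and every tensor is defined; otherwise torch::stack would throw.
bool check_tensors_valid(const std::vector<torch::Tensor>& tensors) {
  if (tensors.empty()) {
    return false;
  }
  for (const auto& t : tensors) {
    if (!t.defined()) {
      return false;  // e.g. a request that supplied no control image
    }
  }
  return true;
}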

xllm/core/framework/request/dit_request_params.cpp

Lines changed: 10 additions & 0 deletions
@@ -270,6 +270,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
     }
   }

+  if (input.has_control_image()) {
+    std::string raw_bytes;
+    if (!butil::Base64Decode(input.control_image(), &raw_bytes)) {
+      LOG(ERROR) << "Base64 control_image decode failed";
+    }
+    if (!decoder.decode(raw_bytes, input_params.control_image)) {
+      LOG(ERROR) << "Control_image decode failed.";
+    }
+  }
+
   // generation params
   const auto& params = request.parameters();
   if (params.has_size()) {
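On the wire, control_image is a base64-encoded image: the handler decodes the payload with brpc's butil::Base64Decode, then runs the image decoder over the raw bytes (note that a failed base64 decode only logs an error and still falls through to the image decode). A hedged client-side sketch of producing that payload; the helper name and file-reading approach are illustrative, not part of this commit:

#include <butil/base64.h>

#include <fstream>
#include <sstream>
#include <string>

// Illustrative helper: read an image file and base64-encode it for the
// control_image field of an ImageGenerationRequest.
std::string encode_control_image(const std::string& path) {
  std::ifstream file(path, std::ios::binary);
  std::ostringstream raw;
  raw << file.rdbuf();
  std::string encoded;
  butil::Base64Encode(raw.str(), &encoded);
  return encoded;  // e.g. input->set_control_image(encoded);
}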

xllm/core/framework/request/dit_request_state.h

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,8 @@ struct DiTInputParams {

   torch::Tensor image;

+  torch::Tensor control_image;
+
   torch::Tensor mask_image;

   torch::Tensor masked_image_latent;

xllm/core/runtime/dit_forward_params.h

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ struct DiTForwardInput {

   torch::Tensor mask_images;

+  torch::Tensor control_image;
+
   torch::Tensor masked_image_latents;

   torch::Tensor prompt_embeds;

xllm/models/dit/autoencoder_kl.h

Lines changed: 28 additions & 5 deletions
@@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
                        bool do_normalize = true,
                        bool do_binarize = false,
                        bool do_convert_rgb = false,
-                       bool do_convert_grayscale = false) {
+                       bool do_convert_grayscale = false,
+                       int64_t latent_channels = 4) {
     const auto& model_args = context.get_model_args();
+    dtype_ = context.get_tensor_options().dtype().toScalarType();
     scale_factor_ = 1 << model_args.block_out_channels().size();
-    latent_channels_ = 4;
+    latent_channels_ = latent_channels;
     do_resize_ = do_resize;
     do_normalize_ = do_normalize;
     do_binarize_ = do_binarize;
@@ -86,8 +88,29 @@ class VAEImageProcessorImpl : public torch::nn::Module {
                  std::optional<int64_t> width = std::nullopt,
                  const std::string& resize_mode = "default",
                  std::optional<std::tuple<int64_t, int64_t, int64_t, int64_t>>
-                     crop_coords = std::nullopt) {
+                     crop_coords = std::nullopt,
+                 const bool is_pil_image = false) {
     torch::Tensor processed = image.clone();
+    if (is_pil_image == true) {
+      auto dims = processed.dim();
+      if (dims < 2 || dims > 4) {
+        LOG(FATAL) << "Unsupported PIL image dimension: " << dims;
+      }
+      if (dims == 4) {
+        if (processed.size(1) == 3 || processed.size(1) == 1) {
+          processed = processed.permute({0, 2, 3, 1});
+        }
+        processed = processed.squeeze(0);
+        dims = processed.dim();
+      }
+      processed = processed.to(torch::kFloat);
+      processed = processed / 255.0f;
+      if (dims == 2) {
+        processed = processed.unsqueeze(0).unsqueeze(0);
+      } else {
+        processed = processed.permute({2, 0, 1}).unsqueeze(0);
+      }
+    }
     if (processed.dtype() != torch::kFloat32) {
       processed = processed.to(torch::kFloat32);
     }
@@ -116,7 +139,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (channel == latent_channels_) {
       return image;
     }
-
     auto [target_h, target_w] =
         get_default_height_width(processed, height, width);
     if (do_resize_) {
@@ -129,7 +151,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (do_binarize_) {
       processed = (processed >= 0.5f).to(torch::kFloat32);
     }
-    processed = processed.to(image.dtype());
+    processed = processed.to(dtype_);
     return processed;
   }

@@ -202,6 +224,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
   bool do_binarize_ = false;
   bool do_convert_rgb_ = false;
   bool do_convert_grayscale_ = false;
+  torch::ScalarType dtype_ = torch::kFloat32;
 };
 TORCH_MODULE(VAEImageProcessor);
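The new is_pil_image branch adapts PIL-style input (HWC layout, uint8 values in 0..255) to the NCHW float layout in [0, 1] that the rest of preprocess() assumes, while the cached dtype_ lets the processor cast results back to the model's tensor dtype rather than the input image's. A standalone sketch of the same layout conversion; shapes and values are illustrative only:

#include <torch/torch.h>

#include <iostream>

int main() {
  // Simulated PIL-style RGB image: HWC, uint8, values in [0, 255].
  torch::Tensor hwc = torch::randint(0, 256, {512, 512, 3}, torch::kUInt8);

  // The same normalization the is_pil_image branch applies:
  // uint8 [0, 255] -> float [0, 1], then HWC -> NCHW with a batch dim.
  torch::Tensor nchw =
      (hwc.to(torch::kFloat) / 255.0f).permute({2, 0, 1}).unsqueeze(0);

  std::cout << nchw.sizes() << std::endl;  // [1, 3, 512, 512]
  return 0;
}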

xllm/models/dit/dit.h

Lines changed: 1 addition & 1 deletion
@@ -1436,7 +1436,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module {
     proj_out_->verify_loaded_weights(prefix + "proj_out.");
   }

-  int64_t in_channels() { return out_channels_; }
+  int64_t in_channels() { return in_channels_; }
   bool guidance_embeds() { return guidance_embeds_; }

  private:
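This one-line fix matters for control models: in FLUX.1-Control style pipelines (as in diffusers' FluxControlPipeline), the packed image latents are concatenated with the packed control latents along the channel axis, so the transformer's input width is twice its output width and in_channels() must not report out_channels_. A sketch of that concatenation, under the assumption that xllm follows the same scheme; shapes are illustrative:

#include <torch/torch.h>

#include <iostream>

int main() {
  // With FLUX's 2x2 latent packing, image and control latents each carry
  // 64 channels per packed token (shapes here are illustrative).
  torch::Tensor latents = torch::randn({1, 4096, 64});
  torch::Tensor control_latents = torch::randn({1, 4096, 64});

  // Control conditioning: concatenate on the channel axis, so the
  // transformer consumes in_channels = 128 while out_channels stays 64.
  torch::Tensor model_input =
      torch::cat({latents, control_latents}, /*dim=*/2);
  std::cout << model_input.sizes() << std::endl;  // [1, 4096, 128]
  return 0;
}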
