Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion xllm/core/framework/batch/dit_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

std::vector<torch::Tensor> images;
std::vector<torch::Tensor> mask_images;

std::vector<torch::Tensor> control_images;
std::vector<torch::Tensor> latents;
std::vector<torch::Tensor> masked_image_latents;
for (const auto& request : request_vec_) {
Expand Down Expand Up @@ -96,6 +96,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

images.emplace_back(input_params.image);
mask_images.emplace_back(input_params.mask_image);
control_images.emplace_back(input_params.control_image);
}

if (input.prompts.size() != request_vec_.size()) {
Expand All @@ -122,6 +123,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
input.mask_images = torch::stack(mask_images);
}

if (check_tensors_valid(control_images)) {
input.control_image = torch::stack(control_images);
}

if (check_tensors_valid(prompt_embeds)) {
input.prompt_embeds = torch::stack(prompt_embeds);
}
Expand Down
10 changes: 10 additions & 0 deletions xllm/core/framework/request/dit_request_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
}
}

if (input.has_control_image()) {
std::string raw_bytes;
if (!butil::Base64Decode(input.control_image(), &raw_bytes)) {
LOG(ERROR) << "Base64 control_image decode failed";
}
if (!decoder.decode(raw_bytes, input_params.control_image)) {
LOG(ERROR) << "Control_image decode failed.";
}
}

// generation params
const auto& params = request.parameters();
if (params.has_size()) {
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/framework/request/dit_request_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ struct DiTInputParams {

torch::Tensor image;

torch::Tensor control_image;

torch::Tensor mask_image;

torch::Tensor masked_image_latent;
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/runtime/dit_forward_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ struct DiTForwardInput {

torch::Tensor mask_images;

torch::Tensor control_image;

torch::Tensor masked_image_latents;

torch::Tensor prompt_embeds;
Expand Down
33 changes: 28 additions & 5 deletions xllm/models/dit/autoencoder_kl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
bool do_normalize = true,
bool do_binarize = false,
bool do_convert_rgb = false,
bool do_convert_grayscale = false) {
bool do_convert_grayscale = false,
int64_t latent_channels = 4) {
const auto& model_args = context.get_model_args();
dtype_ = context.get_tensor_options().dtype().toScalarType();
scale_factor_ = 1 << model_args.block_out_channels().size();
latent_channels_ = 4;
latent_channels_ = latent_channels;
do_resize_ = do_resize;
do_normalize_ = do_normalize;
do_binarize_ = do_binarize;
Expand All @@ -86,8 +88,29 @@ class VAEImageProcessorImpl : public torch::nn::Module {
std::optional<int64_t> width = std::nullopt,
const std::string& resize_mode = "default",
std::optional<std::tuple<int64_t, int64_t, int64_t, int64_t>>
crop_coords = std::nullopt) {
crop_coords = std::nullopt,
const bool is_pil_image = false) {
torch::Tensor processed = image.clone();
if (is_pil_image == true) {
auto dims = processed.dim();
if (dims < 2 || dims > 4) {
LOG(FATAL) << "Unsupported PIL image dimension: " << dims;
}
if (dims == 4) {
if (processed.size(1) == 3 || processed.size(1) == 1) {
processed = processed.permute({0, 2, 3, 1});
}
processed = processed.squeeze(0);
dims = processed.dim();
}
processed = processed.to(torch::kFloat);
processed = processed / 255.0f;
if (dims == 2) {
processed = processed.unsqueeze(0).unsqueeze(0);
} else {
processed = processed.permute({2, 0, 1}).unsqueeze(0);
}
}
if (processed.dtype() != torch::kFloat32) {
processed = processed.to(torch::kFloat32);
}
Expand Down Expand Up @@ -116,7 +139,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
if (channel == latent_channels_) {
return image;
}

auto [target_h, target_w] =
get_default_height_width(processed, height, width);
if (do_resize_) {
Expand All @@ -129,7 +151,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
if (do_binarize_) {
processed = (processed >= 0.5f).to(torch::kFloat32);
}
processed = processed.to(image.dtype());
processed = processed.to(dtype_);
return processed;
}

Expand Down Expand Up @@ -202,6 +224,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
bool do_binarize_ = false;
bool do_convert_rgb_ = false;
bool do_convert_grayscale_ = false;
torch::ScalarType dtype_ = torch::kFloat32;
};
TORCH_MODULE(VAEImageProcessor);

Expand Down
2 changes: 1 addition & 1 deletion xllm/models/dit/dit.h
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module {
proj_out_->verify_loaded_weights(prefix + "proj_out.");
}

int64_t in_channels() { return out_channels_; }
int64_t in_channels() { return in_channels_; }
bool guidance_embeds() { return guidance_embeds_; }

private:
Expand Down
Loading