Skip to content

Commit dd08673

Browse files
committed
refactor llava-1.6 preprocessing
1 parent 13e5f59 commit dd08673

File tree

1 file changed

+128
-105
lines changed

1 file changed

+128
-105
lines changed

examples/llava/clip.cpp

Lines changed: 128 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,15 +1696,6 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32
16961696
// set of tools to manupulate images
16971697
// in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
16981698
struct image_manipulation {
1699-
static inline int clip(int x, int lower, int upper) {
1700-
return std::max(lower, std::min(x, upper));
1701-
}
1702-
1703-
// Linear interpolation between two points
1704-
static inline float lerp(float s, float e, float t) {
1705-
return s + (e - s) * t;
1706-
}
1707-
17081699
// Bilinear resize function
17091700
static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
17101701
dst.nx = target_width;
@@ -1740,6 +1731,8 @@ struct image_manipulation {
17401731
}
17411732
}
17421733

1734+
// Bicubic resize function
1735+
// part of image will be cropped if the aspect ratio is different
17431736
static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
17441737
const int nx = img.nx;
17451738
const int ny = img.ny;
@@ -1804,8 +1797,9 @@ struct image_manipulation {
18041797
}
18051798

18061799
// llava-1.6 type of resize_and_pad
1800+
// if the ratio is not 1:1, padding with fill_color will be applied
18071801
// fill_color is single channel, default is 0 (black)
1808-
static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, uint8_t fill_color = 0) {
1802+
static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array<uint8_t, 3> fill_color = {0, 0, 0}) {
18091803
int target_width = target_resolution.width;
18101804
int target_height = target_resolution.height;
18111805

@@ -1828,7 +1822,14 @@ struct image_manipulation {
18281822
clip_image_u8 padded_image;
18291823
padded_image.nx = target_width;
18301824
padded_image.ny = target_height;
1831-
padded_image.buf.resize(3 * target_width * target_height, fill_color);
1825+
padded_image.buf.resize(3 * target_width * target_height);
1826+
1827+
// Fill the padded image with the fill color
1828+
for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
1829+
padded_image.buf[i] = fill_color[0];
1830+
padded_image.buf[i + 1] = fill_color[1];
1831+
padded_image.buf[i + 2] = fill_color[2];
1832+
}
18321833

18331834
// Calculate padding offsets
18341835
int pad_x = (target_width - new_width) / 2;
@@ -1844,6 +1845,32 @@ struct image_manipulation {
18441845
}
18451846
dst = std::move(padded_image);
18461847
}
1848+
1849+
static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
1850+
dst.nx = w;
1851+
dst.ny = h;
1852+
dst.buf.resize(3 * w * h);
1853+
1854+
for (int i = 0; i < h; ++i) {
1855+
for (int j = 0; j < w; ++j) {
1856+
int src_idx = 3 * ((y + i)*image.nx + (x + j));
1857+
int dst_idx = 3 * (i*w + j);
1858+
dst.buf[dst_idx] = image.buf[src_idx];
1859+
dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
1860+
dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
1861+
}
1862+
}
1863+
}
1864+
1865+
private:
1866+
static inline int clip(int x, int lower, int upper) {
1867+
return std::max(lower, std::min(x, upper));
1868+
}
1869+
1870+
// Linear interpolation between two points
1871+
static inline float lerp(float s, float e, float t) {
1872+
return s + (e - s) * t;
1873+
}
18471874
};
18481875

18491876
/**
@@ -1875,6 +1902,7 @@ struct llava_uhd {
18751902
clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size)
18761903
clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices
18771904
std::vector<slice_coordinates> slices;
1905+
bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6)
18781906
};
18791907

18801908
static int get_max_slices(struct clip_ctx * ctx) {
@@ -1886,47 +1914,79 @@ struct llava_uhd {
18861914

18871915
static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
18881916
slice_instructions res;
1889-
const int patch_size = clip_get_patch_size(ctx);
1890-
const int scale_resolution = clip_get_image_size(ctx);
1891-
const int max_slice_nums = get_max_slices(ctx);
1892-
const int original_width = original_size.width;
1893-
const int original_height = original_size.height;
1917+
const int patch_size = clip_get_patch_size(ctx);
1918+
const int slice_size = clip_get_image_size(ctx);
1919+
const int max_slice_nums = get_max_slices(ctx);
1920+
const int original_width = original_size.width;
1921+
const int original_height = original_size.height;
18941922
const float log_ratio = log((float)original_width / original_height);
1895-
const float ratio = (float)original_width * original_height / (scale_resolution * scale_resolution);
1923+
const float ratio = (float)original_width * original_height / (slice_size * slice_size);
18961924
const int multiple = fmin(ceil(ratio), max_slice_nums);
18971925
const bool has_slices = (multiple > 1);
1926+
const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty();
1927+
1928+
if (has_pinpoints) {
1929+
// has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
1930+
auto refine_size = llava_uhd::select_best_resolution(
1931+
ctx->vision_model.hparams.image_grid_pinpoints,
1932+
original_size);
1933+
res.overview_size = clip_image_size{slice_size, slice_size};
1934+
res.refined_size = refine_size;
1935+
res.grid_size = clip_image_size{0, 0};
1936+
res.padding_refined = true;
1937+
1938+
for (int y = 0; y < refine_size.height; y += slice_size) {
1939+
for (int x = 0; x < refine_size.width; x += slice_size) {
1940+
slice_coordinates slice;
1941+
slice.x = x;
1942+
slice.y = y;
1943+
slice.size.width = std::min(slice_size, refine_size.width - x);
1944+
slice.size.height = std::min(slice_size, refine_size.height - y);
1945+
res.slices.push_back(slice);
1946+
if (x == 0) {
1947+
res.grid_size.width++;
1948+
}
1949+
}
1950+
res.grid_size.height++;
1951+
}
1952+
1953+
return res;
1954+
}
1955+
1956+
// no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
18981957

1899-
auto best_size = get_best_resize(original_size, scale_resolution, patch_size, has_slices);
1958+
auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices);
19001959
res.overview_size = best_size;
1960+
19011961
if (!has_slices) {
19021962
// skip slicing logic
19031963
res.refined_size = clip_image_size{0, 0};
19041964
res.grid_size = clip_image_size{0, 0};
1905-
return res;
1906-
}
19071965

1908-
auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio);
1909-
auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
1910-
res.grid_size = best_grid;
1911-
res.refined_size = refine_size;
1912-
1913-
int width = refine_size.width;
1914-
int height = refine_size.height;
1915-
int grid_x = int(width / best_grid.width);
1916-
int grid_y = int(height / best_grid.height);
1917-
for (int patches_y = 0, ic = 0;
1918-
patches_y < refine_size.height && ic < best_grid.height;
1919-
patches_y += grid_y, ic += 1) {
1920-
for (int patches_x = 0, jc = 0;
1921-
patches_x < refine_size.width && jc < best_grid.width;
1922-
patches_x += grid_x, jc += 1) {
1923-
slice_coordinates slice;
1924-
slice.x = patches_x;
1925-
slice.y = patches_y;
1926-
slice.size.width = grid_x;
1927-
slice.size.height = grid_y;
1928-
res.slices.push_back(slice);
1929-
// LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
1966+
} else {
1967+
auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio);
1968+
auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
1969+
res.grid_size = best_grid;
1970+
res.refined_size = refine_size;
1971+
1972+
int width = refine_size.width;
1973+
int height = refine_size.height;
1974+
int grid_x = int(width / best_grid.width);
1975+
int grid_y = int(height / best_grid.height);
1976+
for (int patches_y = 0, ic = 0;
1977+
patches_y < refine_size.height && ic < best_grid.height;
1978+
patches_y += grid_y, ic += 1) {
1979+
for (int patches_x = 0, jc = 0;
1980+
patches_x < refine_size.width && jc < best_grid.width;
1981+
patches_x += grid_x, jc += 1) {
1982+
slice_coordinates slice;
1983+
slice.x = patches_x;
1984+
slice.y = patches_y;
1985+
slice.size.width = grid_x;
1986+
slice.size.height = grid_y;
1987+
res.slices.push_back(slice);
1988+
// LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
1989+
}
19301990
}
19311991
}
19321992

@@ -1947,7 +2007,11 @@ struct llava_uhd {
19472007

19482008
// resize to refined size
19492009
clip_image_u8_ptr refined_img(clip_image_u8_init());
1950-
image_manipulation::bicubic_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
2010+
if (inst.padding_refined) {
2011+
image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
2012+
} else {
2013+
image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
2014+
}
19512015

19522016
// create slices
19532017
for (const auto & slice : inst.slices) {
@@ -1957,33 +2021,13 @@ struct llava_uhd {
19572021
int h = slice.size.height;
19582022

19592023
clip_image_u8_ptr img_slice(clip_image_u8_init());
1960-
img_slice->nx = w;
1961-
img_slice->ny = h;
1962-
img_slice->buf.resize(3 * w * h);
1963-
for (int i = 0; i < h; ++i) {
1964-
for (int j = 0; j < w; ++j) {
1965-
int src_idx = 3 * ((y + i)*refined_img->nx + (x + j));
1966-
int dst_idx = 3 * (i*w + j);
1967-
img_slice->buf[dst_idx] = refined_img->buf[src_idx];
1968-
img_slice->buf[dst_idx + 1] = refined_img->buf[src_idx + 1];
1969-
img_slice->buf[dst_idx + 2] = refined_img->buf[src_idx + 2];
1970-
}
1971-
}
2024+
image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
19722025
output.push_back(std::move(img_slice));
19732026
}
19742027

19752028
return output;
19762029
}
19772030

1978-
// used by llava 1.6 with custom list of pinpoints
1979-
static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
1980-
std::vector<clip_image_size> possible_resolutions;
1981-
for (size_t i = 0; i < pinpoints.size(); i += 2) {
1982-
possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
1983-
}
1984-
return select_best_resolution(original_size, possible_resolutions);
1985-
}
1986-
19872031
private:
19882032
static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
19892033
int width = original_size.width;
@@ -2032,6 +2076,15 @@ struct llava_uhd {
20322076
return best_fit;
20332077
}
20342078

2079+
// used by llava 1.6 with custom list of pinpoints
2080+
static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
2081+
std::vector<clip_image_size> possible_resolutions;
2082+
for (size_t i = 0; i < pinpoints.size(); i += 2) {
2083+
possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
2084+
}
2085+
return select_best_resolution(original_size, possible_resolutions);
2086+
}
2087+
20352088
static int ensure_divide(int length, int patch_size) {
20362089
return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
20372090
}
@@ -2092,30 +2145,6 @@ struct llava_uhd {
20922145
}
20932146
};
20942147

2095-
// used by llava-1.6, TODO: merge this logic with minicpmv
2096-
static std::vector<clip_image_u8_ptr> divide_to_slices_u8(const clip_image_u8 & image, int slice_size) {
2097-
std::vector<clip_image_u8_ptr> slices;
2098-
int width = image.nx;
2099-
int height = image.ny;
2100-
for (int i = 0; i < height; i += slice_size) {
2101-
for (int j = 0; j < width; j += slice_size) {
2102-
clip_image_u8_ptr patch(clip_image_u8_init());
2103-
patch->nx = std::min(slice_size, width - j);
2104-
patch->ny = std::min(slice_size, height - i);
2105-
patch->buf.resize(3 * patch->nx * patch->ny);
2106-
for (int y = 0; y < patch->ny; ++y) {
2107-
for (int x = 0; x < patch->nx; ++x) {
2108-
for (int c = 0; c < 3; ++c) {
2109-
patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
2110-
}
2111-
}
2112-
}
2113-
slices.push_back(std::move(patch));
2114-
}
2115-
}
2116-
return slices;
2117-
}
2118-
21192148
// TODO @ngxson : decprecate the load_image_size singleton pattern
21202149
int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
21212150
const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
@@ -2125,14 +2154,16 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
21252154
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
21262155
// res_imgs memory is being allocated here, previous allocations will be freed if found
21272156
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
2157+
clip_image_size original_size{img->nx, img->ny};
21282158

21292159
if (clip_is_minicpmv(ctx)) {
2130-
auto const inst = llava_uhd::get_slice_instructions(ctx, ctx->load_image_size);
2160+
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
21312161
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
2132-
for (auto & img : imgs) {
2133-
// clip_image_save_to_bmp(*img, "slice_" + std::to_string(i++) + ".bmp");
2162+
2163+
for (size_t i = 0; i < imgs.size(); ++i) {
2164+
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
21342165
clip_image_f32_ptr res(clip_image_f32_init());
2135-
normalize_image_u8_to_f32(*img, *res, ctx->image_mean, ctx->image_std);
2166+
normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
21362167
res_imgs->entries.push_back(std::move(res));
21372168
}
21382169
return true;
@@ -2205,21 +2236,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
22052236
} else {
22062237
if (!params.image_grid_pinpoints.empty()) {
22072238
// "spatial_unpad" with "anyres" processing for llava-1.6
2208-
clip_image_size best_resolution = llava_uhd::select_best_resolution(
2209-
params.image_grid_pinpoints,
2210-
clip_image_size{img->nx, img->ny});
2211-
2212-
image_manipulation::resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
2213-
2214-
std::vector<clip_image_u8_ptr> slices = divide_to_slices_u8(*temp, params.image_size); // prepare spatial sorted main slices of image_size each (336 in llava-1.6)
2239+
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
2240+
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
22152241

2216-
clip_image_u8_ptr image_original_resize(clip_image_u8_init());
2217-
// bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
2218-
image_manipulation::bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
2219-
slices.insert(slices.begin(), std::move(image_original_resize));
2220-
for (auto & slice : slices) {
2242+
for (size_t i = 0; i < imgs.size(); ++i) {
2243+
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
22212244
clip_image_f32_ptr res(clip_image_f32_init());
2222-
normalize_image_u8_to_f32(*slice, *res, ctx->image_mean, ctx->image_std);
2245+
normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
22232246
res_imgs->entries.push_back(std::move(res));
22242247
}
22252248

0 commit comments

Comments
 (0)