Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions ffcv/libffcv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
from numba import njit
import numpy as np
from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll
from ctypes import CDLL, c_int64, c_uint8, c_uint64, c_float, POINTER, c_void_p, c_uint32, c_bool, cdll
import ffcv._libffcv

lib = CDLL(ffcv._libffcv.__file__)
Expand All @@ -13,10 +13,22 @@
def read(fileno:int, destination:np.ndarray, offset:int):
return read_c(fileno, destination.ctypes.data, destination.size, offset)


ctypes_resize = lib.resize
ctypes_resize.argtypes = 11 * [c_int64]

ctypes_rotate = lib.rotate
ctypes_rotate.argtypes = [c_float, c_int64, c_int64, c_int64, c_int64]

ctypes_shear = lib.shear
ctypes_shear.argtypes = [c_float, c_float, c_int64, c_int64, c_int64, c_int64]

ctypes_add_weighted = lib.add_weighted
ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64]

ctypes_equalize = lib.equalize
ctypes_equalize.argtypes = [c_int64, c_int64, c_int64, c_int64]


def resize_crop(source, start_row, end_row, start_col, end_col, destination):
ctypes_resize(0,
source.ctypes.data,
Expand Down Expand Up @@ -47,5 +59,4 @@ def imdecode(source: np.ndarray, dst: np.ndarray,
ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64]

def memcpy(source: np.ndarray, dest: np.ndarray):
return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size)

return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size)
107 changes: 105 additions & 2 deletions ffcv/transforms/utils/fast_crop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,110 @@
import ctypes
from numba import njit
from numba import njit, prange
import numpy as np
from ...libffcv import ctypes_resize
from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted, ctypes_equalize


"""
Custom equalize -- equivalent to torchvision.transforms.functional.equalize,
but probably slow -- scratch is a (channels, 256) uint16 array.
"""
@njit(parallel=True, fastmath=True, inline='always')
def equalize(source, scratch, destination):
for i in prange(source.shape[-1]):
scratch[i] = np.bincount(source[..., i].flatten(), minlength=256)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunate that np.bincount doesn't have an out argument...

Copy link
Collaborator

@GuillaumeLeclerc GuillaumeLeclerc Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A numba version should be pretty fast and relatively easy to implement no ? (and might even be faster since it would skip the first pass of bincount that checks the min and max values)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good idea. I'll try to add that in the near future.

nonzero_hist = scratch[i][scratch[i] != 0]
step = nonzero_hist[:-1].sum() // 255

if step == 0:
continue

scratch[i][1:] = scratch[i].cumsum()[:-1]
scratch[i] = (scratch[i] + step // 2) // step
scratch[i][0] = 0
np.clip(scratch[i], 0, 255, out=scratch[i])

# numba doesn't like 2d advanced indexing
for row in prange(source.shape[0]):
destination[row, :, i] = scratch[i][source[row, :, i]]

"""
Equalize using OpenCV -- not equivalent to
torchvision.transforms.functional.equalize for so-far-unknown reasons.
"""
@njit(parallel=False, fastmath=True, inline='always')
def fast_equalize(source, chw_scratch, destination):
# this seems kind of hacky
# also, assuming ctypes_equalize allocates a minimal amount of memory
# which may be incorrect -- so maybe we should do this from scratch.
# TODO may be a better way to do this in pure OpenCV
c, h, w = chw_scratch.shape
chw_scratch[0] = source[..., 0]
ctypes_equalize(chw_scratch.ctypes.data,
chw_scratch.ctypes.data,
h, w)
chw_scratch[1] = source[..., 1]
ctypes_equalize(chw_scratch.ctypes.data + h*w,
chw_scratch.ctypes.data + h*w,
h, w)
chw_scratch[2] = source[..., 2]
ctypes_equalize(chw_scratch.ctypes.data + 2*h*w,
chw_scratch.ctypes.data + 2*h*w,
h, w)
destination[..., 0] = chw_scratch[0]
destination[..., 1] = chw_scratch[1]
destination[..., 2] = chw_scratch[2]


@njit(parallel=False, fastmath=True, inline='always')
def invert(source, destination):
destination[:] = 255 - source


@njit(parallel=False, fastmath=True, inline='always')
def solarize(source, threshold, destination):
invert(source, destination)
destination[:] = np.where(source >= threshold, destination, source)


@njit(parallel=False, fastmath=True, inline='always')
def posterize(source, bits, destination):
mask = ~(2 ** (8 - bits) - 1)
destination[:] = source & mask


@njit(inline='always')
def blend(source1, source2, ratio, destination):
ctypes_add_weighted(source1.ctypes.data, ratio,
source2.ctypes.data, 1 - ratio,
destination.ctypes.data,
source1.shape[0], source1.shape[1])


@njit(parallel=False, fastmath=True, inline='always')
def adjust_contrast(source, scratch, factor, destination):
# TODO assuming 3 channels
scratch[:,:,:] = np.mean(0.299 * source[..., 0] +
0.587 * source[..., 1] +
0.114 * source[..., 2])

blend(source, scratch, factor, destination)


@njit(inline='always')
def rotate(source, destination, angle):
ctypes_rotate(angle,
source.ctypes.data,
destination.ctypes.data,
source.shape[0], source.shape[1])


@njit(inline='always')
def shear(source, destination, shear_x, shear_y):
ctypes_shear(shear_x, shear_y,
source.ctypes.data,
destination.ctypes.data,
source.shape[0], source.shape[1])


@njit(inline='always')
def resize_crop(source, start_row, end_row, start_col, end_col, destination):
Expand Down
51 changes: 49 additions & 2 deletions libffcv/libffcv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,64 @@ extern "C" {
int64_t start_row, int64_t end_row, int64_t start_col, int64_t end_col,
int64_t dest_p, int64_t tx, int64_t ty) {
// TODO use proper arguments type

cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(tx, ty, CV_8UC3, (uint8_t*) dest_p);
cv::resize(source_matrix.colRange(start_col, end_col).rowRange(start_row, end_row),
dest_matrix, dest_matrix.size(), 0, 0, cv::INTER_AREA);
}

void rotate(float angle, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);
// TODO unsure if this should be sx, sy
cv::Point2f center((sy-1) / 2.0, (sx-1) / 2.0);
cv::Mat rotation = cv::getRotationMatrix2D(center, angle, 1.0);
cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix, rotation, dest_matrix.size(), cv::INTER_NEAREST);
}

void shear(float shear_x, float shear_y, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);

float _shear[6] = { 1, shear_x, 0, shear_y, 1, 0 };

float cx = (sx - 1) / 2.0;
float cy = (sy - 1) / 2.0;

_shear[2] += _shear[0] * -cx + _shear[1] * -cy;
_shear[5] += _shear[3] * -cx + _shear[4] * -cy;

_shear[2] += cx;
_shear[5] += cy;

cv::Mat shear = cv::Mat(2, 3, CV_32F, _shear);
cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix, shear, dest_matrix.size(), cv::INTER_NEAREST);
}

void add_weighted(int64_t img1_p, float a, int64_t img2_p, float b, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat img1(sx, sy, CV_8UC3, (uint8_t*) img1_p);
cv::Mat img2(sx, sy, CV_8UC3, (uint8_t*) img2_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);

// TODO doubt we need colRange/rowRange stuff
cv::addWeighted(img1.colRange(0, sy).rowRange(0, sx), a,
img2.colRange(0, sy).rowRange(0, sx), b,
0, dest_matrix);
}

void equalize(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8U, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8U, (uint8_t*) dest_p);
cv::equalizeHist(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix);
}

void my_memcpy(void *source, void* dst, uint64_t size) {
memcpy(dst, source, size);
}

void my_fread(int64_t fp, int64_t offset, void *destination, int64_t size) {
fseek((FILE *) fp, offset, SEEK_SET);
fread(destination, 1, size, (FILE *) fp);
Expand Down
1 change: 1 addition & 0 deletions tests/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
example_imgs/*
Loading