1 change: 0 additions & 1 deletion ffcv/fields/ndarray.py
@@ -2,7 +2,6 @@
import warnings
import json
from dataclasses import replace
from kornia import warnings

import numpy as np
import torch as ch
21 changes: 18 additions & 3 deletions ffcv/libffcv.py
@@ -2,7 +2,7 @@
from numba import njit
import numpy as np
import platform
from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll
from ctypes import CDLL, c_int64, c_uint8, c_uint64, c_float, POINTER, c_void_p, c_uint32, c_bool, cdll
import ffcv._libffcv

lib = CDLL(ffcv._libffcv.__file__)
@@ -22,6 +22,22 @@ def read(fileno:int, destination:np.ndarray, offset:int):
ctypes_resize = lib.resize
ctypes_resize.argtypes = 11 * [c_int64]

ctypes_rotate = lib.rotate
# rotate(angle, source_ptr, dest_ptr, rows, cols)
ctypes_rotate.argtypes = [c_float, c_int64, c_int64, c_int64, c_int64]

ctypes_shear = lib.shear
# shear(shear_x, shear_y, source_ptr, dest_ptr, rows, cols)
ctypes_shear.argtypes = [c_float, c_float, c_int64, c_int64, c_int64, c_int64]

ctypes_add_weighted = lib.add_weighted
# add_weighted(img1_ptr, alpha, img2_ptr, beta, dest_ptr, rows, cols)
ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64]

ctypes_equalize = lib.equalize
# equalize(source_ptr, dest_ptr, rows, cols)
ctypes_equalize.argtypes = 4 * [c_int64]

ctypes_unsharp_mask = lib.unsharp_mask
# unsharp_mask(source_ptr, dest_ptr, rows, cols)
ctypes_unsharp_mask.argtypes = 4 * [c_int64]


def resize_crop(source, start_row, end_row, start_col, end_col, destination):
ctypes_resize(0,
source.ctypes.data,
@@ -52,5 +68,4 @@ def imdecode(source: np.ndarray, dst: np.ndarray,
ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64]

def memcpy(source: np.ndarray, dest: np.ndarray):
return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize)

return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize)
1 change: 0 additions & 1 deletion ffcv/pipeline/graph.py
@@ -2,7 +2,6 @@
import warnings
import ast

import astor
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Set
from abc import ABC, abstractmethod
183 changes: 181 additions & 2 deletions ffcv/transforms/utils/fast_crop.py
@@ -1,7 +1,186 @@
import ctypes
from numba import njit
from numba import njit, prange
import numpy as np
from ...libffcv import ctypes_resize
from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, \
ctypes_add_weighted, ctypes_equalize, ctypes_unsharp_mask

"""
Requires a float32 scratch array
"""
@njit(parallel=True, fastmath=True, inline='always')
def autocontrast(source, scratchf, destination):
# numba: no kwargs in min? as a consequence, I might as well have written
# this in C++
# TODO assuming 3 channels
minimum = [source[..., 0].min(), source[..., 1].min(), source[..., 2].min()]
maximum = [source[..., 0].max(), source[..., 1].max(), source[..., 2].max()]
scale = [0.0, 0.0, 0.0]
for i in prange(source.shape[-1]):
if minimum[i] == maximum[i]:
scale[i] = 1
minimum[i] = 0
else:
scale[i] = 255. / (maximum[i] - minimum[i])
for i in prange(source.shape[-1]):
scratchf[..., i] = source[..., i] - minimum[i]
scratchf[..., i] = scratchf[..., i] * scale[i]
np.clip(scratchf, 0, 255, out=scratchf)
destination[:] = scratchf
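
A minimal usage sketch for the scratch contract above (hypothetical names; the float32 scratch shape is inferred from the assignments in the function body):

import numpy as np

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)   # any HWC uint8 image
scratchf = np.empty(img.shape, dtype=np.float32)                # float32 scratch, same shape as the image
out = np.empty_like(img)
autocontrast(img, scratchf, out)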


"""
Custom equalize -- equivalent to torchvision.transforms.functional.equalize,
but probably slow -- scratch is a (channels, 256) uint16 array.
"""
@njit(parallel=True, fastmath=True, inline='always')
def equalize(source, scratch, destination):
for i in prange(source.shape[-1]):
# TODO memory less than ideal for bincount() and hist()
scratch[i] = np.bincount(source[..., i].flatten(), minlength=256)
Author:
Unfortunate that np.bincount doesn't have an out argument...

@GuillaumeLeclerc (Collaborator), Feb 16, 2022:
A numba version should be pretty fast and relatively easy to implement, no? (It might even be faster, since it would skip the first pass of bincount that checks the min and max values.)

Author:
Yeah, good idea. I'll try to add that in the near future.
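
A rough sketch of what that numba histogram could look like (hypothetical, not part of this PR; it fills a preallocated 256-bin row directly, with no extra min/max pass):

from numba import njit

@njit(fastmath=True)
def channel_hist(channel, hist):
    # channel: 2D uint8 plane; hist: preallocated length-256 integer array
    hist[:] = 0
    for i in range(channel.shape[0]):
        for j in range(channel.shape[1]):
            hist[channel[i, j]] += 1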

nonzero_hist = scratch[i][scratch[i] != 0]
step = nonzero_hist[:-1].sum() // 255

if step == 0:
continue

scratch[i][1:] = scratch[i].cumsum()[:-1]
scratch[i] = (scratch[i] + step // 2) // step
scratch[i][0] = 0
np.clip(scratch[i], 0, 255, out=scratch[i])

# numba doesn't like 2d advanced indexing
for row in prange(source.shape[0]):
destination[row, :, i] = scratch[i][source[row, :, i]]
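
A minimal sanity-check sketch for the equivalence claimed in the docstring (hypothetical; assumes torchvision's tensor equalize is available and that the claimed equivalence holds):

import numpy as np
import torch
from torchvision.transforms.functional import equalize as tv_equalize

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
scratch = np.zeros((3, 256), dtype=np.uint16)   # (channels, 256) histogram scratch
out = np.empty_like(img)
equalize(img, scratch, out)

ref = tv_equalize(torch.from_numpy(img).permute(2, 0, 1)).permute(1, 2, 0).numpy()
assert np.array_equal(out, ref)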

"""
Equalize using OpenCV -- not equivalent to
torchvision.transforms.functional.equalize for so-far-unknown reasons.
"""
@njit(parallel=False, fastmath=True, inline='always')
def fast_equalize(source, chw_scratch, destination):
# this seems kind of hacky
# also, assuming ctypes_equalize allocates a minimal amount of memory
# which may be incorrect -- so maybe we should do this from scratch.
# TODO may be a better way to do this in pure OpenCV
c, h, w = chw_scratch.shape
chw_scratch[0] = source[..., 0]
ctypes_equalize(chw_scratch.ctypes.data,
chw_scratch.ctypes.data,
h, w)
chw_scratch[1] = source[..., 1]
ctypes_equalize(chw_scratch.ctypes.data + h*w,
chw_scratch.ctypes.data + h*w,
h, w)
chw_scratch[2] = source[..., 2]
ctypes_equalize(chw_scratch.ctypes.data + 2*h*w,
chw_scratch.ctypes.data + 2*h*w,
h, w)
destination[..., 0] = chw_scratch[0]
destination[..., 1] = chw_scratch[1]
destination[..., 2] = chw_scratch[2]


@njit(parallel=False, fastmath=True, inline='always')
def invert(source, destination):
destination[:] = 255 - source


@njit(parallel=False, fastmath=True, inline='always')
def solarize(source, threshold, destination):
invert(source, destination)
destination[:] = np.where(source >= threshold, destination, source)


@njit(parallel=False, fastmath=True, inline='always')
def posterize(source, bits, destination):
mask = ~(2 ** (8 - bits) - 1)
destination[:] = source & mask


@njit(inline='always')
def blend(source1, source2, ratio, destination):
ctypes_add_weighted(source1.ctypes.data, ratio,
source2.ctypes.data, 1 - ratio,
destination.ctypes.data,
source1.shape[0], source1.shape[1])


@njit(parallel=False, fastmath=True, inline='always')
def adjust_saturation(source, scratch, factor, destination):
# TODO numpy autocasting probably allocates memory here,
# should be more careful.
# TODO do we really need scratch for this? could use destination
scratch[...,0] = 0.299 * source[..., 0] + \
0.587 * source[..., 1] + \
0.114 * source[..., 2]
scratch[...,1] = scratch[...,0]
scratch[...,2] = scratch[...,1]

blend(source, scratch, factor, destination)


@njit(parallel=False, fastmath=True, inline='always')
def adjust_contrast(source, scratch, factor, destination):
# TODO assuming 3 channels
scratch[:,:,:] = np.mean(0.299 * source[..., 0] +
0.587 * source[..., 1] +
0.114 * source[..., 2])

blend(source, scratch, factor, destination)


@njit(fastmath=True, inline='always')
def sharpen(source, destination, amount):
ctypes_unsharp_mask(source.ctypes.data,
destination.ctypes.data,
source.shape[0], source.shape[1])

# in PyTorch's implementation,
# the border is unaffected
destination[0,:] = source[0,:]
destination[1:,0] = source[1:,0]
destination[-1,:] = source[-1,:]
destination[1:-1,-1] = source[1:-1,-1]

blend(source, destination, amount, destination)


"""
Translation, x and y
Assuming this is faster than warpAffine;
also assuming tx and ty are ints
"""
@njit(inline='always')
def translate(source, destination, tx, ty):
if tx > 0:
destination[:, tx:] = source[:, :-tx]
destination[:, :tx] = 0
if tx < 0:
destination[:, :tx] = source[:, -tx:]
destination[:, tx:] = 0
if ty > 0:
destination[ty:, :] = source[:-ty, :]
destination[:ty, :] = 0
if ty < 0:
destination[:ty, :] = source[-ty:, :]
destination[ty:, :] = 0
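
A tiny worked example of the shift-and-zero-fill semantics above (hypothetical values; a single 2D plane for readability, the same slicing applies unchanged to HWC images):

import numpy as np

src = np.arange(12, dtype=np.uint8).reshape(3, 4)   # first row is [0, 1, 2, 3]
dst = np.empty_like(src)
translate(src, dst, 2, 0)   # tx=2: shift right by two columns, zero-fill the left
# dst[0] is now [0, 0, 0, 1]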


@njit(inline='always')
def rotate(source, destination, angle):
ctypes_rotate(angle,
source.ctypes.data,
destination.ctypes.data,
source.shape[0], source.shape[1])


@njit(inline='always')
def shear(source, destination, shear_x, shear_y):
ctypes_shear(shear_x, shear_y,
source.ctypes.data,
destination.ctypes.data,
source.shape[0], source.shape[1])


@njit(inline='always')
def resize_crop(source, start_row, end_row, start_col, end_col, destination):
67 changes: 66 additions & 1 deletion libffcv/libffcv.cpp
@@ -34,13 +34,78 @@ extern "C" {
int64_t start_row, int64_t end_row, int64_t start_col, int64_t end_col,
int64_t dest_p, int64_t tx, int64_t ty) {
// TODO use proper arguments type

cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(tx, ty, CV_8UC3, (uint8_t*) dest_p);
cv::resize(source_matrix.colRange(start_col, end_col).rowRange(start_row, end_row),
dest_matrix, dest_matrix.size(), 0, 0, cv::INTER_AREA);
}


EXPORT void rotate(float angle, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);
// TODO unsure if this should be sx, sy
cv::Point2f center((sy-1) / 2.0, (sx-1) / 2.0);
cv::Mat rotation = cv::getRotationMatrix2D(center, angle, 1.0);
cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix, rotation, dest_matrix.size(), cv::INTER_NEAREST);
}

EXPORT void shear(float shear_x, float shear_y, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);

float _shear[6] = { 1, shear_x, 0, shear_y, 1, 0 };

float cx = (sx - 1) / 2.0;
float cy = (sy - 1) / 2.0;

_shear[2] += _shear[0] * -cx + _shear[1] * -cy;
_shear[5] += _shear[3] * -cx + _shear[4] * -cy;

_shear[2] += cx;
_shear[5] += cy;

cv::Mat shear = cv::Mat(2, 3, CV_32F, _shear);
cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix, shear, dest_matrix.size(), cv::INTER_NEAREST);
}
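
For reference, the in-place adjustments above fold a shift-to-origin, shear, shift-back composition into a single affine map: with S the pure shear matrix and c = (cx, cy) the center, the translation column becomes c - S*c, i.e. the combined transform is T(c) * S * T(-c), so the shear is applied about the image center rather than the top-left corner.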

EXPORT void add_weighted(int64_t img1_p, float a, int64_t img2_p, float b, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat img1(sx, sy, CV_8UC3, (uint8_t*) img1_p);
cv::Mat img2(sx, sy, CV_8UC3, (uint8_t*) img2_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);

// TODO doubt we need colRange/rowRange stuff
cv::addWeighted(img1.colRange(0, sy).rowRange(0, sx), a,
img2.colRange(0, sy).rowRange(0, sx), b,
0, dest_matrix);
}

EXPORT void equalize(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8U, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8U, (uint8_t*) dest_p);
cv::equalizeHist(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix);
}

EXPORT void unsharp_mask(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) {
cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p);
cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p);

cv::Point anchor(-1, -1);

// 3x3 unsharp kernel: all ones with 5 in the center, divided by the
// kernel sum (13), so 1/13 ~= 0.0769 and 5/13 ~= 0.3846
float _kernel[9] = { 0.0769, 0.0769, 0.0769,
                     0.0769, 0.3846, 0.0769,
                     0.0769, 0.0769, 0.0769 };
cv::Mat kernel = cv::Mat(3, 3, CV_32F, _kernel);

cv::filter2D(source_matrix.colRange(0, sy).rowRange(0, sx),
dest_matrix, -1, kernel, anchor, 0, cv::BORDER_ISOLATED);

//add_weighted(source_p, amount, dest_p, 1 - amount, dest_p, sx, sy);
}

EXPORT void my_memcpy(void *source, void* dst, uint64_t size) {
memcpy(dst, source, size);
}
1 change: 1 addition & 0 deletions tests/.gitignore
@@ -0,0 +1 @@
example_imgs/*