@@ -1,8 +1,8 @@
#include <THC.h>
#include <THCGeneral.h>

-#define CUDA_NUM_THREADS 512
-#define THREADS_PER_BLOCK 64
+#define CUDA_NUM_THREADS 512
+#define THREADS_PER_BLOCK 64

#define DIM0(TENSOR) ((TENSOR).x)
#define DIM1(TENSOR) ((TENSOR).y)
@@ -82,16 +82,16 @@ __global__ void kernel_ChannelNorm_backward_input1(const int n, const float* inp

void ChannelNorm_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* output, int norm_deg) {
int n = 0;

const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]);
const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]);

const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]);
const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]);

n = THCudaTensor_nElement(state, output);
-kernel_ChannelNorm_updateOutput<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>(
-n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride,
+kernel_ChannelNorm_updateOutput<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream() >>>(
+n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride,
norm_deg);

THCudaCheck(cudaGetLastError());
@@ -113,7 +113,7 @@ void ChannelNorm_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTe
const long4 gradInput1_stride = make_long4(gradInput1->stride[0], gradInput1->stride[1], gradInput1->stride[2], gradInput1->stride[3]);

n = THCudaTensor_nElement(state, gradInput1);
-kernel_ChannelNorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>(
+kernel_ChannelNorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream() >>>(
n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride,
THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput1), gradInput1_size, gradInput1_stride,
norm_deg
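
The change repeated across these CUDA files swaps the TH-era THCState_getCurrentStream(state) for c10::cuda::getCurrentCUDAStream() (declared in c10/cuda/CUDAStream.h), which returns the stream PyTorch currently considers active without needing a THCState handle. A practical consequence is that extension kernels now follow torch.cuda.stream() contexts; the sketch below illustrates this under the assumption of a CUDA device, a built extension, and the import path shown (the correlation1d name and the .apply call are taken from this PR's test changes):

import torch
from correlation_package.functions.corr import correlation1d  # import path assumed

a = torch.randn(1, 3, 8, 8, device='cuda')
b = torch.randn(1, 3, 8, 8, device='cuda')

side = torch.cuda.Stream()
with torch.cuda.stream(side):
    # Inside the extension, c10::cuda::getCurrentCUDAStream() now resolves
    # to `side`, so the kernel is enqueued there instead of the default stream.
    out = correlation1d(20, 1, 20, 1, 1, 1).apply(a, b)
side.synchronize()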
@@ -15,7 +15,7 @@ int corr1d_cuda_forward(THCudaTensor *input1,
int stride1,
int stride2,
int corr_type_multiply
-//single_direction=0
+//single_direction=0
)
{

@@ -44,7 +44,7 @@ int corr1d_cuda_forward(THCudaTensor *input1,
int x_shift = -neighborhood_grid_radius_;

// Number of output channels amounts to displacement combinations in X direction only!!
-int nOutputPlane = neighborhood_grid_width_;//Same, because 1D X-correlation
+int nOutputPlane = neighborhood_grid_width_;//Same, because 1D X-correlation

// Inputs
float * input1_data = THCudaTensor_data(state, input1);
@@ -64,7 +64,7 @@ int corr1d_cuda_forward(THCudaTensor *input1,
float * rbot1_data = THCudaTensor_data(state, rbot1);
float * rbot2_data = THCudaTensor_data(state, rbot2);

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

int pwidthheight = paddedbottomwidth * paddedbottomheight;

@@ -145,7 +145,7 @@ int corr1d_cuda_backward(THCudaTensor *input1,

int pwidthheight = paddedbottomwidth * paddedbottomheight;

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

blob_rearrange_ongpu_1d(input1_data,rbot1_data,batchSize,nInputPlane,nInputCols,nInputRows,inputWidthHeight,pad_size,pwidthheight,stream);

@@ -62,7 +62,7 @@ int corr_cuda_forward(THCudaTensor *input1,
float * rbot1_data = THCudaTensor_data(state, rbot1);
float * rbot2_data = THCudaTensor_data(state, rbot2);

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

int pwidthheight = paddedbottomwidth * paddedbottomheight;

@@ -141,7 +141,7 @@ int corr_cuda_backward(THCudaTensor *input1,

int pwidthheight = paddedbottomwidth * paddedbottomheight;

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

blob_rearrange_ongpu(input1_data,rbot1_data,batchSize,nInputPlane,nInputCols,nInputRows,inputWidthHeight,pad_size,pwidthheight,stream);


This file was deleted.

@@ -1,12 +0,0 @@

-from torch.utils.ffi import _wrap_function
-from ._corr import lib as _lib, ffi as _ffi
-
-__all__ = []
-def _import_symbols(locals):
-    for symbol in dir(_lib):
-        fn = getattr(_lib, symbol)
-        locals[symbol] = _wrap_function(fn, _ffi)
-        __all__.append(symbol)
-
-_import_symbols(locals())
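
This deleted wrapper depended on torch.utils.ffi, which was removed from PyTorch when extensions moved to torch.utils.cpp_extension; the compiled module is now a regular Python extension, so there is nothing left for _wrap_function to wrap. A minimal sketch of the replacement import, assuming the module name given to ext_modules in the new setup.py and that the C sources export their entry points (e.g. via pybind11 bindings):

# The cffi wrapper layer is gone; the built extension is imported directly.
# 'correlation_package._ext.corr' is the name passed to ext_fnct() in setup.py.
from correlation_package._ext import corr

# Entry points such as corr.corr_cuda_forward(...) would then be ordinary
# module-level functions (assuming bindings are exported).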
@@ -13,10 +13,11 @@ def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, st
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

+    @staticmethod
    def forward(self, input1, input2):

        self.save_for_backward(input1, input2)

        rbot1 = input1.new()
        rbot2 = input2.new()
        output = input1.new()
@@ -33,6 +34,7 @@ def forward(self, input1, input2):

        return output

+    @staticmethod
    def backward(self, grad_output):

        input1, input2 = self.saved_tensors
@@ -71,6 +73,7 @@ def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, st
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

+    @staticmethod
    def forward(self, input1, input2):

        self.save_for_backward(input1, input2)
@@ -91,6 +94,7 @@ def forward(self, input1, input2):

        return output

+    @staticmethod
    def backward(self, grad_output):

        input1, input2 = self.saved_tensors
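
These hunks add @staticmethod because old-style, instantiable autograd Functions were removed from PyTorch: Function.apply now invokes the static forward/backward and passes a context object as the first argument (the parameter here is still named self, but it binds to that context, not to an instance). A minimal new-style sketch of the same pattern, with illustrative names and the per-call hyper-parameters routed through forward instead of stored on an instance:

import torch

class CorrelationFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2, pad_size, kernel_size,
                max_displacement, stride1, stride2, corr_multiply):
        ctx.save_for_backward(input1, input2)
        ctx.params = (pad_size, kernel_size, max_displacement,
                      stride1, stride2, corr_multiply)
        output = input1.new()  # filled in by the C/CUDA kernel in the real code
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        grad_input1 = torch.zeros_like(input1)  # computed by the CUDA kernel in the real code
        grad_input2 = torch.zeros_like(input2)
        # One gradient (or None) per forward argument after ctx:
        return grad_input1, grad_input2, None, None, None, None, None, None

Callers never instantiate the class; they invoke CorrelationFunction.apply(input1, input2, ...) directly.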
@@ -16,6 +16,7 @@ def __init__(self, pad_size=None, kernel_size=None, max_displacement=None,
    def reset_params(self):
        return

+    @staticmethod
    def forward(self, input1, input2):
        return correlation(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)

@@ -40,6 +41,7 @@ def __init__(self, pad_size=None, kernel_size=None, max_displacement=None,
    def reset_params(self):
        return

+    @staticmethod
    def forward(self, input1, input2):
        return correlation1d(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)
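
Because a Function can no longer be called like a module, the usual companion to this change is a thin nn.Module that keeps the hyper-parameters and whose ordinary (non-static) forward dispatches to Function.apply. A sketch of that wrapper, reusing the illustrative CorrelationFunction from the previous snippet:

import torch.nn as nn

class Correlation(nn.Module):
    def __init__(self, pad_size=3, kernel_size=3, max_displacement=20,
                 stride1=1, stride2=2, corr_multiply=1):
        super().__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):
        return CorrelationFunction.apply(
            input1, input2, self.pad_size, self.kernel_size,
            self.max_displacement, self.stride1, self.stride2,
            self.corr_multiply)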

@@ -15,17 +15,17 @@ int corr1d_cuda_forward(THCudaTensor *input1,
int stride1,
int stride2,
int corr_type_multiply
-//single_direction=0
+//single_direction=0
)
{

// TODO: Shapechecks

-int batchSize = input1->size[0];
+int batchSize = input1->size(0);

-long nInputPlane = input1->size[1];
-long nInputRows = input1->size[2];
-long nInputCols = input1->size[3];
+long nInputPlane = input1->size(1);
+long nInputRows = input1->size(2);
+long nInputCols = input1->size(3);
long inputWidthHeight = nInputRows * nInputCols;

long kernel_radius_ = (kernel_size - 1) / 2;
@@ -44,7 +44,7 @@ int corr1d_cuda_forward(THCudaTensor *input1,
int x_shift = -neighborhood_grid_radius_;

// Number of output channels amounts to displacement combinations in X direction only!!
-int nOutputPlane = neighborhood_grid_width_;//Same, because 1D X-correlation
+int nOutputPlane = neighborhood_grid_width_;//Same, because 1D X-correlation

// Inputs
float * input1_data = THCudaTensor_data(state, input1);
@@ -64,7 +64,7 @@ int corr1d_cuda_forward(THCudaTensor *input1,
float * rbot1_data = THCudaTensor_data(state, rbot1);
float * rbot2_data = THCudaTensor_data(state, rbot2);

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

int pwidthheight = paddedbottomwidth * paddedbottomheight;

@@ -103,10 +103,10 @@ int corr1d_cuda_backward(THCudaTensor *input1,
float * input1_data = THCudaTensor_data(state, input1);
float * input2_data = THCudaTensor_data(state, input2);

-long nInputCols = input1->size[3];
-long nInputRows = input1->size[2];
-long nInputPlane = input1->size[1];
-long batchSize = input1->size[0];
+long nInputCols = input1->size(3);
+long nInputRows = input1->size(2);
+long nInputPlane = input1->size(1);
+long batchSize = input1->size(0);

// THCudaTensor_resizeAs(state, gradInput1, input1);
// THCudaTensor_resizeAs(state, gradInput2, input2);
@@ -145,7 +145,7 @@ int corr1d_cuda_backward(THCudaTensor *input1,

int pwidthheight = paddedbottomwidth * paddedbottomheight;

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

blob_rearrange_ongpu_1d(input1_data,rbot1_data,batchSize,nInputPlane,nInputCols,nInputRows,inputWidthHeight,pad_size,pwidthheight,stream);

@@ -20,11 +20,11 @@ int corr_cuda_forward(THCudaTensor *input1,

// TODO: Shapechecks

-int batchSize = input1->size[0];
+int batchSize = input1->size(0);

-long nInputPlane = input1->size[1];
-long nInputRows = input1->size[2];
-long nInputCols = input1->size[3];
+long nInputPlane = input1->size(1);
+long nInputRows = input1->size(2);
+long nInputCols = input1->size(3);
long inputWidthHeight = nInputRows * nInputCols;

long kernel_radius_ = (kernel_size - 1) / 2;
@@ -62,7 +62,7 @@ int corr_cuda_forward(THCudaTensor *input1,
float * rbot1_data = THCudaTensor_data(state, rbot1);
float * rbot2_data = THCudaTensor_data(state, rbot2);

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

int pwidthheight = paddedbottomwidth * paddedbottomheight;

@@ -100,10 +100,10 @@ int corr_cuda_backward(THCudaTensor *input1,
float * input1_data = THCudaTensor_data(state, input1);
float * input2_data = THCudaTensor_data(state, input2);

-long nInputCols = input1->size[3];
-long nInputRows = input1->size[2];
-long nInputPlane = input1->size[1];
-long batchSize = input1->size[0];
+long nInputCols = input1->size(3);
+long nInputRows = input1->size(2);
+long nInputPlane = input1->size(1);
+long batchSize = input1->size(0);

// THCudaTensor_resizeAs(state, gradInput1, input1);
// THCudaTensor_resizeAs(state, gradInput2, input2);
@@ -141,7 +141,7 @@ int corr_cuda_backward(THCudaTensor *input1,

int pwidthheight = paddedbottomwidth * paddedbottomheight;

-cudaStream_t stream = THCState_getCurrentStream(state);
+cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

blob_rearrange_ongpu(input1_data,rbot1_data,batchSize,nInputPlane,nInputCols,nInputRows,inputWidthHeight,pad_size,pwidthheight,stream);

@@ -4,24 +4,40 @@

from setuptools import setup, find_packages

+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

this_file = os.path.dirname(__file__)


+corr_source = ['correlation_package/src/corr.c', 'correlation_package/src/corr1d.c']
+corr_includes = ['correlation_package/src/']
+
+if torch.cuda.is_available():
+    print('Including CUDA code.')
+    ext_fnct = CUDAExtension
+    corr_source += ['correlation_package/src/corr_cuda.c', 'correlation_package/src/corr1d_cuda.c']
+    corr_source += ['correlation_package/src/corr_cuda_kernel.cu', 'correlation_package/src/corr1d_cuda_kernel.cu']
+else:
+    ext_fnct = CppExtension

setup(
name="correlation_package",
version="0.1",
description="Correlation layer from FlowNetC",
url="https://github.com/jbarker-nvidia/pytorch-correlation",
author="Jon Barker",
author_email="[email protected]",
# Require cffi
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
# Exclude the build files.
packages=find_packages(exclude=["build"]),
# Package where to put the extensions. Has to be a prefix of build.py
ext_package="",
# Extensions to compile
-    cffi_modules=[
-        os.path.join(this_file, "build.py:ffi")
+    ext_modules=[
+        ext_fnct(
+            'correlation_package._ext.corr',
+            corr_source, include_dirs=corr_includes,
+            extra_compile_args={'cxx': ['-std=c++14']},),
    ],
+    cmdclass={'build_ext': BuildExtension}
)
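
The build moves from cffi_modules/build.py to torch.utils.cpp_extension: CUDAExtension or CppExtension is chosen by torch.cuda.is_available(), BuildExtension drives compilation, and python setup.py build install --user becomes the whole build step. For quick experiments the same sources can also be JIT-compiled with torch.utils.cpp_extension.load(); the sketch below reuses paths from corr_source above, and whether these TH-era sources compile unmodified against a given PyTorch version is an assumption:

from torch.utils.cpp_extension import load

# JIT-compile the CUDA correlation sources in-process, without setup.py.
corr = load(
    name='corr',
    sources=['correlation_package/src/corr_cuda.c',
             'correlation_package/src/corr_cuda_kernel.cu'],
    extra_cflags=['-std=c++14'],
    verbose=True,
)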
@@ -99,7 +99,7 @@ def test_correlation1d_0():


    model2 = correlation1d(1, 1, 1, 1, 1, 1)
-    y2 = model2(A_, B_)
+    y2 = model2.apply(A_, B_)
    print(y2) # should be 1x3x2x2

    return
@@ -113,7 +113,7 @@ def test_correlation1d():

    #import pdb; pdb.set_trace()
    model = correlation1d(20, 1, 20, 1, 1, 1)
-    y = model(A_, B_)
+    y = model.apply(A_, B_)
    print(y.size())

    print('Functional interface test passed')
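
With the new-style interface the tests call the Function through .apply. A natural follow-up check is torch.autograd.gradcheck on the migrated backward, sketched here assuming a CUDA build; gradcheck prefers double precision, which these float-only kernels do not provide, so the loose tolerances are illustrative and the check may still need adjustment:

import torch
from torch.autograd import gradcheck

A = torch.randn(1, 1, 4, 4, device='cuda', requires_grad=True)
B = torch.randn(1, 1, 4, 4, device='cuda', requires_grad=True)

model = correlation1d(1, 1, 1, 1, 1, 1)
ok = gradcheck(lambda a, b: model.apply(a, b), (A, B),
               eps=1e-3, atol=1e-2, rtol=1e-2, nondet_tol=1e-3)
print('gradcheck passed:', ok)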
@@ -1,13 +1,7 @@
#!/usr/bin/env bash

-CUDA_PATH=/usr/local/cuda-8.0
-

cd correlation-pytorch/correlation_package/src
-echo "Compiling correlation layer kernels by nvcc..."

-# TODO (JEB): Check which arches we need
-nvcc -c -o corr_cuda_kernel.cu.o corr_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
-nvcc -c -o corr1d_cuda_kernel.cu.o corr1d_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
-
cd ../../
-python setup.py build install
+python setup.py build install --user
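
The hand-rolled nvcc lines (and the hard-coded -arch=sm_52) can go because BuildExtension now compiles the .cu files itself and picks GPU architectures from the device it detects at build time, or from the TORCH_CUDA_ARCH_LIST environment variable. A small illustrative override, assuming the cpp_extension-based build from setup.py:

import os

# Optional: pin the architectures cpp_extension compiles for; the list
# below is only an example.
os.environ['TORCH_CUDA_ARCH_LIST'] = '5.2;6.1;7.0'
# Then run: python setup.py build install --user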