Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 104 additions & 42 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
reg64_t reg_ow_pos = rdx;
reg64_t aux_reg_output = reg_ow_pos;
reg64_t reg_dg_iter = reg_output;
reg64_t reg_gr_iter = rsp;
reg64_t aux_reg_input = rax;
reg64_t aux2_reg_input = reg_kernel;
reg64_t reg_ic_iter = rbx;
Expand Down Expand Up @@ -163,6 +164,10 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
dd(1);
}

for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
dd(-1);
}
}

void apply_filter(int ow_step, int oc_blocks_step, int oc_step, int ic_step) {
Expand Down Expand Up @@ -359,16 +364,32 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
cvtsi2ss(xmm_ih_im, reg_tmp_32);
addss(xmm_ih_im, xmm_map_h);

movss(xmm_tmp, xmm_ih_im);
cmpss(xmm_tmp, table_val(0), 1);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
jne(init_with_zeros, T_NEAR);
if (jcp_.with_bi_pad) {
movss(xmm_tmp, xmm_ih_im);
cvtps2dq(xmm_tmp, xmm_tmp);
cmpss(xmm_tmp, table_val(6), 0x0e);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
jne(init_with_zeros, T_NEAR);

movss(xmm_tmp, xmm_ih_im);
cvtps2dq(xmm_tmp, xmm_tmp);
cmpss(xmm_tmp, table_val(1), 1);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
je(init_with_zeros, T_NEAR);
} else {
movss(xmm_tmp, xmm_ih_im);
cmpss(xmm_tmp, table_val(0), 1);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
jne(init_with_zeros, T_NEAR);

cmpss(xmm_ih_im, table_val(1), 1);
movq(reg_tmp_64, xmm_ih_im);
cmp(reg_tmp_32, 0);
je(init_with_zeros, T_NEAR);
cmpss(xmm_ih_im, table_val(1), 1);
movq(reg_tmp_64, xmm_ih_im);
cmp(reg_tmp_32, 0);
je(init_with_zeros, T_NEAR);
}


size_t def_off_w = ((2 * (kh * jcp_.kw + kw) + 1) * jcp_.oh * jcp_.ow) + ow;
Expand All @@ -387,16 +408,33 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
cvtsi2ss(xmm_iw_im, reg_tmp_32);
addss(xmm_iw_im, xmm_map_w);

movss(xmm_tmp, xmm_iw_im);
cmpss(xmm_tmp, table_val(0), 1);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
jne(init_with_zeros, T_NEAR);
if (jcp_.with_bi_pad) {
movss(xmm_tmp, xmm_iw_im);
cvtps2dq(xmm_tmp, xmm_tmp);
cmpss(xmm_tmp, table_val(6), 0x0e);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
jne(init_with_zeros, T_NEAR);

movss(xmm_tmp, xmm_iw_im);
cvtps2dq(xmm_tmp, xmm_tmp);
cmpss(xmm_tmp, table_val(2), 1);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
je(init_with_zeros, T_NEAR);
} else {
movss(xmm_tmp, xmm_iw_im);
cmpss(xmm_tmp, table_val(0), 1);
movq(reg_tmp_64, xmm_tmp);
cmp(reg_tmp_32, 0);
jne(init_with_zeros, T_NEAR);

cmpss(xmm_iw_im, table_val(2), 1);
movq(reg_tmp_64, xmm_iw_im);
cmp(reg_tmp_32, 0);
je(init_with_zeros, T_NEAR);
}

cmpss(xmm_iw_im, table_val(2), 1);
movq(reg_tmp_64, xmm_iw_im);
cmp(reg_tmp_32, 0);
je(init_with_zeros, T_NEAR);

// interpolation calculation

Expand Down Expand Up @@ -853,7 +891,6 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
L(oc_unrolled_loop); {
cmp(reg_oc_work, jcp_.nb_oc_blocking * jcp_.oc_block);
jl(oc_main_loop, T_NEAR);

ic_loop(ow_step, jcp_.nb_oc_blocking, jcp_.oc_block);
store_output(ow_step, jcp_.nb_oc_blocking, jcp_.oc_block);

Expand All @@ -869,7 +906,6 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
L(oc_main_loop); {
cmp(reg_oc_work, jcp_.oc_block);
jl(oc_tail, T_NEAR);

ic_loop(ow_step, 1, jcp_.oc_block);
store_output(ow_step, 1, jcp_.oc_block);

Expand Down Expand Up @@ -967,6 +1003,12 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

const int simd_w = mayiuse(cpu::x64::avx512_common) ? 16 : 8;
if (group != 1 && (((getParentEdgeAt(0)->getDims()[1] / group) % simd_w != 0)
|| ((getChildEdgeAt(0)->getDims()[1] / group) % simd_w != 0))) {
enforceRef = true;
}

size_t inputsNumber = getOriginalInputsNumber();
InferenceEngine::LayerConfig config;
config.dynBatchSupport = false;
Expand All @@ -987,7 +1029,9 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
config.outConfs[0].inPlace = -1;

impl_desc_type impl_type;
if (mayiuse(cpu::x64::avx512_common)) {
if (enforceRef) {
impl_type = impl_desc_type::ref;
} else if (mayiuse(cpu::x64::avx512_common)) {
impl_type = impl_desc_type::jit_avx512;
} else if (mayiuse(cpu::x64::avx2)) {
impl_type = impl_desc_type::jit_avx2;
Expand All @@ -997,8 +1041,8 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
impl_type = impl_desc_type::ref;
}

if (mayiuse(cpu::x64::sse41)) {
// optimzed implementation
if (!enforceRef && mayiuse(cpu::x64::sse41)) {
// optimized implementation
auto dataFormat = memory::format_tag::nhwc;
auto offFormat = memory::format_tag::nchw;
auto weiFormat = group > 1 ? mayiuse(avx512_common) ? memory::format_tag::gOIhw16i16o : memory::format_tag::gOIhw8i8o
Expand Down Expand Up @@ -1062,9 +1106,9 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {
jcp.oh = dstDims[2];
jcp.ow = dstDims[3];

bool with_groups = group > 1;
jcp.kh = weiDims[with_groups + 2];
jcp.kw = weiDims[with_groups + 3];
// bool with_groups = group > 1;
jcp.kh = weiDims[2];
jcp.kw = weiDims[3];

jcp.t_pad = paddingL[0];
jcp.l_pad = paddingL[1];
Expand Down Expand Up @@ -1097,7 +1141,9 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {

jcp.nthr = dnnl_get_max_threads();

if (mayiuse(cpu::x64::avx512_common)) {
if (enforceRef) {
return;
} else if (mayiuse(cpu::x64::avx512_common)) {
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
} else if (mayiuse(cpu::x64::avx2)) {
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
Expand Down Expand Up @@ -1147,7 +1193,7 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f

for (int ic = 0; ic < IC; ic++) {
const float *data_im_ptr = src + mb * src_strides[0] + (g * IC + ic) * src_strides[1] + h_in * src_strides[2] + w_in * src_strides[3];
const int deformable_group_index = ic / channel_per_deformable_group;
const int deformable_group_index = (IC * g + ic) / channel_per_deformable_group;
const float *data_offset_ptr = offsets + mb * off_strides[0] + (deformable_group_index * 2 * KH * KW) * off_strides[1];
const float *modulation_offset_ptr = nullptr;
if (modulation != nullptr) {
Expand All @@ -1165,22 +1211,38 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f

const float h_im = h_in + map_h; // absolute pixel index with offset
const float w_im = w_in + map_w; // absolute pixel index with offset
if (h_im >= 0 && w_im >= 0 && h_im < IH && w_im < IW) {
const int cur_height = IH - h_in;
const int cur_width = IW - w_in;
int h_low = std::max(static_cast<int>(floorf(map_h)), 0);
int w_low = std::max(static_cast<int>(floorf(map_w)), 0);
int h_high = with_bi_pad ? h_low + 1 : std::min(static_cast<int>(ceilf(map_h)), cur_height - 1);
int w_high = with_bi_pad ? w_low + 1 : std::min(static_cast<int>(ceilf(map_w)), cur_width - 1);
bool skip_compute;
if (with_bilinear_pad) {
skip_compute = !(static_cast<int>(w_im) > -1 &&
static_cast<int>(w_im) < IW &&
static_cast<int>(h_im) > -1 &&
static_cast<int>(h_im) < IH);
} else {
skip_compute = !(w_im >= 0 &&
w_im < IW &&
h_im >= 0 &&
h_im < IH);
}
if (!skip_compute) {
const int cur_h_end = IH - h_in;
const int cur_w_end = IW - w_in;
int h_low = with_bi_pad ? static_cast<int>(floorf(map_h)) :
std::max(static_cast<int>(floorf(map_h)), 0);
int w_low = with_bi_pad ? static_cast<int>(floorf(map_w)) :
std::max(static_cast<int>(floorf(map_w)), 0);
const int cur_h_start = h_low + h_in;
const int cur_w_start = w_low + w_in;
int h_high = with_bi_pad ? h_low + 1 : std::min(static_cast<int>(ceilf(map_h)), cur_h_end - 1);
int w_high = with_bi_pad ? w_low + 1 : std::min(static_cast<int>(ceilf(map_w)), cur_w_end - 1);

float lh = map_h - h_low;
float lw = map_w - w_low;
float hh = 1 - lh, hw = 1 - lw;

float v1 = (w_low >= 0 && h_low >= 0) ? data_im_ptr[h_low * src_strides[2] + w_low * src_strides[3]] : 0.0f;
float v2 = (w_high < cur_width && h_low >= 0) ? data_im_ptr[h_low * src_strides[2] + w_high * src_strides[3]] : 0.0f;
float v3 = (w_low >= 0 && h_high < cur_height) ? data_im_ptr[h_high * src_strides[2] + w_low * src_strides[3]] : 0.0f;
float v4 = (w_high < cur_width && h_high < cur_height) ? data_im_ptr[h_high * src_strides[2] + w_high * src_strides[3]] : 0.0f;
float v1 = (cur_w_start >= 0 && cur_h_start >= 0) ? data_im_ptr[h_low * src_strides[2] + w_low * src_strides[3]] : 0.0f;
float v2 = (w_high < cur_w_end && cur_h_start >= 0) ? data_im_ptr[h_low * src_strides[2] + w_high * src_strides[3]] : 0.0f;
float v3 = (cur_w_start >= 0 && h_high < cur_h_end) ? data_im_ptr[h_high * src_strides[2] + w_low * src_strides[3]] : 0.0f;
float v4 = (w_high < cur_w_end && h_high < cur_h_end) ? data_im_ptr[h_high * src_strides[2] + w_high * src_strides[3]] : 0.0f;
float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
Expand All @@ -1192,8 +1254,8 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
modulation_scalar = modulation_offset_ptr[modulation_index];
}

const float weight = with_groups ? weights[g * wei_strides[0] + oc * wei_strides[1] + ic * wei_strides[2] + kh * wei_strides[3] +
kw * wei_strides[4]]
const float weight = with_groups ? weights[(g + oc / G) * wei_strides[0] + ic * wei_strides[1] + kh * wei_strides[2] +
kw * wei_strides[3]]
: weights[oc * wei_strides[0] + ic * wei_strides[1] + kh * wei_strides[2] + kw * wei_strides[3]];
d += val * weight * modulation_scalar;
}
Expand All @@ -1205,7 +1267,7 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
};

parallel_nd(G, MB, OC, OH, OW,
[&](int g, int mb, int oc, int oh, int ow) {
[&](int g, int mb, int oc, int oh, int ow) {
dst[mb * dst_strides[0] + (g * OC + oc) * dst_strides[1] + oh * dst_strides[2] + ow * dst_strides[3]] = ker(g, mb, oc, oh, ow);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class MKLDNNDeformableConvolutionNode : public MKLDNNNode {
bool canBeInPlace() const override {
return false;
}
bool enforceRef = false;

InferenceEngine::Precision getRuntimePrecision() const override;

Expand Down