Skip to content

Commit 950a916

Browse files
committed
Merge pull request opencv#17752 from YashasSamaga:generalize-concat-fusion-3.4
2 parents 3f13339 + b7eec21 commit 950a916

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

modules/dnn/src/dnn.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2525,8 +2525,7 @@ struct Net::Impl : public detail::NetImplBase
2525 2525
// (and so we eliminate the concatenation layer, because the channels
2526 2526
// are concatenated implicitly).
2527 2527
Ptr<ConcatLayer> concatLayer = ld.layerInstance.dynamicCast<ConcatLayer>();
2528-
if( !concatLayer.empty() && concatLayer->axis == 1 && !concatLayer->padding &&
2529-
ld.outputBlobs.size() == 1 )
2528+
if( !concatLayer.empty() && !concatLayer->padding && ld.outputBlobs.size() == 1 )
2530 2529
{
2531 2530
Mat& output = ld.outputBlobs[0];
2532 2531
UMat umat_output;
@@ -2563,7 +2562,8 @@ struct Net::Impl : public detail::NetImplBase
2563 2562
// the concatenation optimization is applied with batch_size > 1.
2564 2563
// so, for now, we only apply this optimization in the most popular
2565 2564
// case batch_size == 1.
2566-
if( output.dims == 4 && output.size[0] == 1 )
2565+
int axis = clamp(concatLayer->axis, output.dims);
2566+
if( output.total(0, axis) == 1 )
2567 2567
{
2568 2568
size_t i, ninputs = ld.inputBlobsId.size();
2569 2569
std::vector<LayerPin> realinputs(ninputs);
@@ -2602,14 +2602,14 @@ struct Net::Impl : public detail::NetImplBase
2602 2602
OpenCLBackendWrapper::update(ld.outputBlobsWrappers, umats);
2603 2603
}
2604 2604
#endif
2605-
Range chrange[] = { Range::all(), Range::all(), Range::all(), Range::all() };
2605+
std::vector<Range> chrange(output.dims, Range::all());
2606 2606
int ofs = 0;
2607 2607
for( i = 0; i < ninputs; i++ )
2608 2608
{
2609 2609
LayerPin pin = realinputs[i];
2610 2610
LayerData* inp_i_data = &layers[pin.lid];
2611-
int channels_i = ld.inputBlobs[i]->size[1];
2612-
chrange[1] = Range(ofs, ofs + channels_i);
2611+
int channels_i = ld.inputBlobs[i]->size[axis];
2612+
chrange[axis] = Range(ofs, ofs + channels_i);
2613 2613
printf_(("\toutput %s(%d) to channels (%d, %d)\n", inp_i_data->layerInstance->name.c_str(),
2614 2614
pin.oid, ofs, ofs + channels_i));
2615 2615
ofs += channels_i;

0 commit comments

Comments
 (0)