@@ -1407,46 +1407,30 @@ struct Net::Impl
             if( ld.consumers.size() == 1 && pinsToKeep.count(LayerPin(lid, 0)) == 0 )
             {
                 LayerData* nextData = &layers[ld.consumers[0].lid];
-                Ptr<BatchNormLayer> nextBNormLayer =
-                    nextData->layerInstance.dynamicCast<BatchNormLayer>();
                 LayerPin lpNext(ld.consumers[0].lid, 0);
-                if( !nextBNormLayer.empty() && pinsToKeep.count(lpNext) == 0 )
+                while (nextData)
                 {
-                    LayerData* bnormData = nextData;
-                    nextData = 0;
-                    if( currLayer->setBatchNorm(nextBNormLayer) )
+                    Ptr<Layer> nextLayer = nextData->layerInstance;
+                    if (currLayer->tryFuse(nextLayer))
                     {
-                        printf_(("\tfused with %s\n", nextBNormLayer->name.c_str()));
-                        bnormData->skip = true;
+                        printf_(("\tfused with %s\n", nextLayer->name.c_str()));
+                        nextData->skip = true;
                         ld.outputBlobs = layers[lpNext.lid].outputBlobs;
                         ld.outputBlobsWrappers = layers[lpNext.lid].outputBlobsWrappers;
-                        if( bnormData->consumers.size() == 1 )
+                        if (nextData->consumers.size() == 1)
                         {
-                            nextData = &layers[bnormData->consumers[0].lid];
-                            lpNext = LayerPin(bnormData->consumers[0].lid, 0);
+                            int nextLayerId = nextData->consumers[0].lid;
+                            nextData = &layers[nextLayerId];
+                            lpNext = LayerPin(nextLayerId, 0);
                         }
-                    }
-                }
-
-                Ptr<ScaleLayer> nextScaleLayer;
-                if( nextData )
-                    nextScaleLayer = nextData->layerInstance.dynamicCast<ScaleLayer>();
-                if( !nextScaleLayer.empty() && pinsToKeep.count(lpNext) == 0 )
-                {
-                    LayerData* scaleData = nextData;
-                    nextData = 0;
-                    if( currLayer->setScale(nextScaleLayer) )
-                    {
-                        printf_(("\tfused with %s\n", nextScaleLayer->name.c_str()));
-                        scaleData->skip = true;
-                        ld.outputBlobs = layers[lpNext.lid].outputBlobs;
-                        ld.outputBlobsWrappers = layers[lpNext.lid].outputBlobsWrappers;
-                        if( scaleData->consumers.size() == 1 )
+                        else
                         {
-                            nextData = &layers[scaleData->consumers[0].lid];
-                            lpNext = LayerPin(scaleData->consumers[0].lid, 0);
+                            nextData = 0;
+                            break;
                         }
                     }
+                    else
+                        break;
                 }
 
                 // For now, OpenCL target support fusion with activation of ReLU/ChannelsPReLU/Power/Tanh
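
The rewritten block above replaces the two hard-coded BatchNorm-then-Scale checks with one greedy walk: while the fused layer's `tryFuse` accepts its single consumer, that consumer is marked `skip` and the walk moves one link further down the chain; the first refusal, or a branch point (more than one consumer), stops it. A minimal standalone sketch of that pattern follows, assuming illustrative stand-in types (`FusableLayer`, `Node`, `fuseChain`) rather than the real `Layer`/`LayerData` plumbing:

```cpp
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-ins; not the cv::dnn types from the patch.
struct FusableLayer
{
    std::string name;
    // Generic hook mirroring the new Layer::tryFuse: return true if 'next' was absorbed.
    virtual bool tryFuse(const std::shared_ptr<FusableLayer>& /*next*/) { return false; }
    virtual ~FusableLayer() {}
};

struct Node
{
    std::shared_ptr<FusableLayer> layer;
    std::vector<Node*> consumers;
    bool skip = false;
};

// Greedily fuse the single-consumer chain that starts right after 'currLayer'.
static void fuseChain(FusableLayer* currLayer, Node* nextData)
{
    while (nextData)
    {
        if (!currLayer->tryFuse(nextData->layer))
            break;                              // first refusal stops the walk
        std::printf("\tfused with %s\n", nextData->layer->name.c_str());
        nextData->skip = true;                  // the absorbed layer becomes a no-op
        if (nextData->consumers.size() != 1)
            break;                              // only a linear chain is collapsed
        nextData = nextData->consumers[0];      // keep walking down the chain
    }
}
```

Stopping at a branch point only ends the walk; the layer that was just absorbed is still skipped, and in the real patch `ld.outputBlobs` is re-pointed so downstream consumers keep reading from the location they expect.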
@@ -2627,13 +2611,16 @@ Ptr<BackendNode> Layer::tryAttach(const Ptr<BackendNode>& node)
 }
 
 bool Layer::setActivation(const Ptr<ActivationLayer>&) { return false; }
-bool Layer::setBatchNorm(const Ptr<BatchNormLayer>&) { return false; }
-bool Layer::setScale(const Ptr<ScaleLayer>&) { return false; }
+bool Layer::tryFuse(Ptr<Layer>&) { return false; }
+void Layer::getScaleShift(Mat& scale, Mat& shift) const
+{
+    scale = Mat();
+    shift = Mat();
+}
+
 void Layer::unsetAttached()
 {
     setActivation(Ptr<ActivationLayer>());
-    setBatchNorm(Ptr<BatchNormLayer>());
-    setScale(Ptr<ScaleLayer>());
 }
 
 template <typename T>
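
The API change gives fusion a generic contract: `tryFuse` lets a layer try to absorb an arbitrary top layer, and `getScaleShift` lets a layer describe itself as a per-channel affine transform (the base implementation returns empty `Mat`s, i.e. nothing to fold). A hedged sketch of how such a pair could interact is below; `AffineLikeLayer`, `ConvLikeLayer` and the weight-folding arithmetic are illustrative assumptions, not code from this patch:

```cpp
#include <opencv2/core.hpp>
#include <memory>
#include <vector>

// Illustrative classes mirroring the two new hooks; not the cv::dnn::Layer hierarchy.
struct FusableLayer
{
    // Report a per-channel map y = x * scale + shift; empty Mats mean "not an affine layer".
    virtual void getScaleShift(cv::Mat& scale, cv::Mat& shift) const
    {
        scale = cv::Mat();
        shift = cv::Mat();
    }
    virtual bool tryFuse(const std::shared_ptr<FusableLayer>&) { return false; }
    virtual ~FusableLayer() {}
};

// A BatchNorm/Scale-like layer: it *is* a per-channel affine transform.
struct AffineLikeLayer : FusableLayer
{
    cv::Mat channelScale, channelShift;  // 1 x numChannels, CV_32F
    void getScaleShift(cv::Mat& scale, cv::Mat& shift) const override
    {
        scale = channelScale;
        shift = channelShift;
    }
};

// A convolution-like layer: it absorbs such a transform into its own weights and bias.
struct ConvLikeLayer : FusableLayer
{
    std::vector<cv::Mat> kernels;  // one kernel per output channel
    cv::Mat bias;                  // 1 x numOutputChannels, CV_32F
    bool tryFuse(const std::shared_ptr<FusableLayer>& next) override
    {
        cv::Mat scale, shift;
        next->getScaleShift(scale, shift);
        if (scale.empty() && shift.empty())
            return false;                                   // nothing to fold in
        for (int c = 0; c < (int)kernels.size(); ++c)
        {
            float s = scale.empty() ? 1.f : scale.at<float>(c);
            float b = shift.empty() ? 0.f : shift.at<float>(c);
            kernels[c] *= s;                                // w' = s * w
            bias.at<float>(c) = s * bias.at<float>(c) + b;  // b' = s * b + shift
        }
        return true;
    }
};
```

Splitting the contract this way is what keeps the loop in the first hunk layer-agnostic: the fusing layer never needs to know whether its consumer is BatchNorm, Scale, or anything else that can express itself as a scale/shift pair.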