@@ -804,10 +804,14 @@ TEST(Layer, ExpandLayer) {
804
804
testExpandLayer (" seq" , true ); // seq expand to hasSubseq
805
805
}
806
806
807
- void testDegradeLayer (bool hasSubseq, string layer_type, string trans_type) {
807
+ void testDegradeLayer (bool hasSubseq,
808
+ string layer_type,
809
+ string trans_type,
810
+ int stride) {
808
811
TestConfig config;
809
812
config.layerConfig .set_type (layer_type);
810
813
config.layerConfig .set_size (10 );
814
+ config.layerConfig .set_seq_pool_stride (stride);
811
815
config.biasSize = 0 ;
812
816
813
817
config.inputDefs .push_back (
@@ -827,36 +831,46 @@ void testDegradeLayer(bool hasSubseq, string layer_type, string trans_type) {
827
831
if (layer_type == " average" ) {
828
832
for (auto strategy : {" average" , " sum" , " squarerootn" }) {
829
833
LOG (INFO) << " hasSubseq=" << hasSubseq << " trans_type=" << trans_type
830
- << " average_strategy=" << strategy;
834
+ << " average_strategy=" << strategy
835
+ << " seq_pool_stride=" << stride;
831
836
config.layerConfig .set_average_strategy (strategy);
832
837
testDegradeLayerGrad (config, layer_type);
833
838
}
834
839
} else {
835
- LOG (INFO) << " hasSubseq=" << hasSubseq << " trans_type=" << trans_type;
840
+ LOG (INFO) << " hasSubseq=" << hasSubseq << " trans_type=" << trans_type
841
+ << " seq_pool_stride=" << stride;
836
842
testDegradeLayerGrad (config, layer_type);
837
843
}
838
844
}
839
845
840
846
TEST (Layer, MaxLayer) {
841
- testDegradeLayer (false , " max" , " non-seq" ); // seq max to non-seq
842
- testDegradeLayer (true , " max" , " non-seq" ); // hasSubseq max to non-seq
843
- testDegradeLayer (true , " max" , " seq" ); // hasSubseq max to seq
847
+ testDegradeLayer (false , " max" , " non-seq" , - 1 ); // seq max to non-seq
848
+ testDegradeLayer (true , " max" , " non-seq" , - 1 ); // hasSubseq max to non-seq
849
+ testDegradeLayer (true , " max" , " seq" , - 1 ); // hasSubseq max to seq
844
850
}
845
851
846
852
TEST (Layer, SequenceLastInstanceLayer) {
847
853
testDegradeLayer (false ,
848
854
" seqlastins" ,
849
- " non-seq" ); // seq seqlastins to non-seq
855
+ " non-seq" ,
856
+ -1 ); // seq seqlastins to non-seq
857
+ testDegradeLayer (false ,
858
+ " seqlastins" ,
859
+ " non-seq" ,
860
+ 5 ); // seq seqlastins to a shorten seq, stride window = 5
850
861
testDegradeLayer (true ,
851
862
" seqlastins" ,
852
- " non-seq" ); // hasSubseq seqlastins to non-seq
853
- testDegradeLayer (true , " seqlastins" , " seq" ); // hasSubseq seqlastins to seq
863
+ " non-seq" ,
864
+ -1 ); // hasSubseq seqlastins to non-seq
865
+ testDegradeLayer (
866
+ true , " seqlastins" , " seq" , -1 ); // hasSubseq seqlastins to seq
854
867
}
855
868
856
869
TEST (Layer, AverageLayer) {
857
- testDegradeLayer (false , " average" , " non-seq" ); // seq average to non-seq
858
- testDegradeLayer (true , " average" , " non-seq" ); // hasSubseq average to non-seq
859
- testDegradeLayer (true , " average" , " seq" ); // hasSubseq average to seq
870
+ testDegradeLayer (false , " average" , " non-seq" , -1 ); // seq average to non-seq
871
+ testDegradeLayer (
872
+ true , " average" , " non-seq" , -1 ); // hasSubseq average to non-seq
873
+ testDegradeLayer (true , " average" , " seq" , -1 ); // hasSubseq average to seq
860
874
}
861
875
862
876
TEST (Layer, SequenceConcatLayer) {
0 commit comments