@@ -38,20 +38,18 @@ namespace datasets
3838class ReorderLayerDataset
3939{
4040public:
41- using type = std::tuple<TensorShape, TensorShape, WeightFormat, WeightFormat, bool >;
41+ using type = std::tuple<TensorShape, TensorShape, WeightFormat, WeightFormat>;
4242
4343 struct iterator
4444 {
4545 iterator (std::vector<TensorShape>::const_iterator in_it,
4646 std::vector<TensorShape>::const_iterator out_it,
4747 std::vector<WeightFormat>::const_iterator _wf_in_it,
48- std::vector<WeightFormat>::const_iterator _wf_out_it,
49- std::vector<bool >::const_iterator _transposes_it)
48+ std::vector<WeightFormat>::const_iterator _wf_out_it)
5049 : _in_it{ std::move (in_it) },
5150 _out_it{ std::move (out_it) },
5251 _wf_in_it{ std::move (_wf_in_it) },
53- _wf_out_it{ std::move (_wf_out_it) },
54- _transposes_it{ std::move (_transposes_it) }
52+ _wf_out_it{ std::move (_wf_out_it) }
5553 {
5654 }
5755
@@ -62,13 +60,12 @@ class ReorderLayerDataset
6260 description << " Out=" << *_out_it << " :" ;
6361 description << " Wf_In=" << *_wf_in_it << " :" ;
6462 description << " Wf_Out=" << *_wf_out_it;
65- description << " Transpose=" << *_transposes_it;
6663 return description.str ();
6764 }
6865
6966 ReorderLayerDataset::type operator *() const
7067 {
71- return std::make_tuple (*_in_it, *_out_it, *_wf_in_it, *_wf_out_it, *_transposes_it );
68+ return std::make_tuple (*_in_it, *_out_it, *_wf_in_it, *_wf_out_it);
7269 }
7370
7471 iterator &operator ++()
@@ -77,7 +74,6 @@ class ReorderLayerDataset
7774 ++_out_it;
7875 ++_wf_in_it;
7976 ++_wf_out_it;
80- ++_transposes_it;
8177
8278 return *this ;
8379 }
@@ -87,26 +83,24 @@ class ReorderLayerDataset
8783 std::vector<TensorShape>::const_iterator _out_it;
8884 std::vector<WeightFormat>::const_iterator _wf_in_it;
8985 std::vector<WeightFormat>::const_iterator _wf_out_it;
90- std::vector<bool >::const_iterator _transposes_it;
9186 };
9287
9388 iterator begin () const
9489 {
95- return iterator (_in_shapes.begin (), _out_shapes.begin (), _in_wfs.begin (), _out_wfs.begin (), _transposes. begin () );
90+ return iterator (_in_shapes.begin (), _out_shapes.begin (), _in_wfs.begin (), _out_wfs.begin ());
9691 }
9792
9893 int size () const
9994 {
100- return std::min (_in_shapes.size (), std::min (_out_shapes.size (), std::min (_in_wfs.size (), std::min ( _out_wfs.size (), _transposes. size () ))));
95+ return std::min (_in_shapes.size (), std::min (_out_shapes.size (), std::min (_in_wfs.size (), _out_wfs.size ())));
10196 }
10297
103- void add_config (TensorShape in, TensorShape out, WeightFormat in_wf, WeightFormat out_wf, bool transpose )
98+ void add_config (TensorShape in, TensorShape out, WeightFormat in_wf, WeightFormat out_wf)
10499 {
105100 _in_shapes.emplace_back (std::move (in));
106101 _out_shapes.emplace_back (std::move (out));
107102 _in_wfs.emplace_back (std::move (in_wf));
108103 _out_wfs.emplace_back (std::move (out_wf));
109- _transposes.emplace_back (transpose);
110104 }
111105
112106 // protected:
@@ -118,7 +112,6 @@ class ReorderLayerDataset
118112 std::vector<TensorShape> _out_shapes{};
119113 std::vector<WeightFormat> _in_wfs{};
120114 std::vector<WeightFormat> _out_wfs{};
121- std::vector<bool > _transposes{};
122115};
123116
124117/* * [ReorderLayer datasets] **/
@@ -128,16 +121,16 @@ class ReorderLayerDatasetBlock4 final : public ReorderLayerDataset
128121 public:
129122 ReorderLayerDatasetBlock4 ()
130123 {
131- add_config (TensorShape (10U , 9U ), TensorShape (10U , 12U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
132- add_config (TensorShape (16U , 16U ), TensorShape (16U , 16U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
133- add_config (TensorShape (10U , 511U ), TensorShape (10U , 512U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
134- add_config (TensorShape (234U , 301U ), TensorShape (234U , 304U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
135- add_config (TensorShape (1024U , 1024U ), TensorShape (1024U , 1024U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
136- add_config (TensorShape (10U , 9U , 1U , 1U ), TensorShape (10U , 12U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
137- add_config (TensorShape (16U , 16U , 1U , 1U ), TensorShape (16U , 16U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
138- add_config (TensorShape (10U , 511U , 1U , 1U ), TensorShape (10U , 512U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
139- add_config (TensorShape (234U , 301U , 1U , 1U ), TensorShape (234U , 304U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
140- add_config (TensorShape (1024U , 1024U , 1U , 1U ), TensorShape (1024U , 1024U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4, true );
124+ add_config (TensorShape (10U , 9U ), TensorShape (10U , 12U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
125+ add_config (TensorShape (16U , 16U ), TensorShape (16U , 16U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
126+ add_config (TensorShape (10U , 511U ), TensorShape (10U , 512U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
127+ add_config (TensorShape (234U , 301U ), TensorShape (234U , 304U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
128+ add_config (TensorShape (1024U , 1024U ), TensorShape (1024U , 1024U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
129+ add_config (TensorShape (10U , 9U , 1U , 1U ), TensorShape (10U , 12U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
130+ add_config (TensorShape (16U , 16U , 1U , 1U ), TensorShape (16U , 16U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
131+ add_config (TensorShape (10U , 511U , 1U , 1U ), TensorShape (10U , 512U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
132+ add_config (TensorShape (234U , 301U , 1U , 1U ), TensorShape (234U , 304U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
133+ add_config (TensorShape (1024U , 1024U , 1U , 1U ), TensorShape (1024U , 1024U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo4);
141134 }
142135};
143136
@@ -146,16 +139,16 @@ class ReorderLayerDatasetBlock8 final : public ReorderLayerDataset
146139 public:
147140 ReorderLayerDatasetBlock8 ()
148141 {
149- add_config (TensorShape (10U , 9U ), TensorShape (10U , 16U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
150- add_config (TensorShape (16U , 16U ), TensorShape (16U , 16U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
151- add_config (TensorShape (10U , 511U ), TensorShape (10U , 512U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
152- add_config (TensorShape (234U , 301U ), TensorShape (234U , 304U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
153- add_config (TensorShape (1024U , 1024U ), TensorShape (1024U , 1024U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
154- add_config (TensorShape (10U , 9U , 1U , 1U ), TensorShape (10U , 16U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
155- add_config (TensorShape (16U , 16U , 1U , 1U ), TensorShape (16U , 16U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
156- add_config (TensorShape (10U , 511U , 1U , 1U ), TensorShape (10U , 512U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
157- add_config (TensorShape (234U , 301U , 1U , 1U ), TensorShape (234U , 304U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
158- add_config (TensorShape (1024U , 1024U , 1U , 1U ), TensorShape (1024U , 1024U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8, true );
142+ add_config (TensorShape (10U , 9U ), TensorShape (10U , 16U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
143+ add_config (TensorShape (16U , 16U ), TensorShape (16U , 16U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
144+ add_config (TensorShape (10U , 511U ), TensorShape (10U , 512U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
145+ add_config (TensorShape (234U , 301U ), TensorShape (234U , 304U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
146+ add_config (TensorShape (1024U , 1024U ), TensorShape (1024U , 1024U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
147+ add_config (TensorShape (10U , 9U , 1U , 1U ), TensorShape (10U , 16U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
148+ add_config (TensorShape (16U , 16U , 1U , 1U ), TensorShape (16U , 16U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
149+ add_config (TensorShape (10U , 511U , 1U , 1U ), TensorShape (10U , 512U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
150+ add_config (TensorShape (234U , 301U , 1U , 1U ), TensorShape (234U , 304U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
151+ add_config (TensorShape (1024U , 1024U , 1U , 1U ), TensorShape (1024U , 1024U , 1U , 1U ), WeightFormat::OHWI, WeightFormat::OHWIo8);
159152 }
160153};
161154
0 commit comments