@@ -1082,6 +1082,93 @@ namespace dlib
10821082 float avg_blue;
10831083 };
10841084
1085+ // ----------------------------------------------------------------------------------------
1086+
1087+ class input_tensor
1088+ {
1089+ public:
1090+ typedef tensor input_type;
1091+
1092+ input_tensor () {}
1093+ input_tensor (const input_tensor&) {}
1094+
1095+ template <typename forward_iterator>
1096+ void to_tensor (
1097+ forward_iterator ibegin,
1098+ forward_iterator iend,
1099+ resizable_tensor& data
1100+ ) const
1101+ {
1102+ DLIB_CASSERT (std::distance (ibegin, iend) > 0 );
1103+ const auto k = ibegin->k ();
1104+ const auto nr = ibegin->nr ();
1105+ const auto nc = ibegin->nc ();
1106+ // make sure all the input tensors have the same dimensions
1107+ for (auto i = ibegin; i != iend; ++i)
1108+ {
1109+ DLIB_CASSERT (i->k () == k && i->nr () == nr && i->nc () == nc,
1110+ " \t input_tensor::to_tensor()"
1111+ << " \n\t All tensor objects given to to_tensor() must have the same dimensions."
1112+ << " \n\t k: " << k
1113+ << " \n\t nr: " << nr
1114+ << " \n\t nc: " << nc
1115+ << " \n\t i->k(): " << i->k ()
1116+ << " \n\t i->nr(): " << i->nr ()
1117+ << " \n\t i->nc(): " << i->nc ()
1118+ );
1119+ }
1120+
1121+ const auto num_samples = count_samples (ibegin, iend);
1122+ // initialize data to the right size to contain the stuff in the iterator range.
1123+ data.set_size (num_samples, k, nr, nc);
1124+
1125+ const size_t stride = k * nr * nc;
1126+ size_t offset = 0 ;
1127+ for (auto i = ibegin; i != iend; ++i)
1128+ {
1129+ alias_tensor slice (i->num_samples (), k, nr, nc);
1130+ memcpy (slice (data, offset), *i);
1131+ offset += slice.num_samples () * stride;
1132+ }
1133+ }
1134+
1135+ friend void serialize (const input_tensor&, std::ostream& out)
1136+ {
1137+ serialize (" input_tensor" , out);
1138+ }
1139+
1140+ friend void deserialize (input_tensor&, std::istream& in)
1141+ {
1142+ std::string version;
1143+ deserialize (version, in);
1144+ if (version != " input_tensor" )
1145+ throw serialization_error (" Unexpected version found while deserializing dlib::input_tensor." );
1146+ }
1147+
1148+ friend std::ostream& operator <<(std::ostream& out, const input_tensor&)
1149+ {
1150+ out << " input_tensor" ;
1151+ return out;
1152+ }
1153+
1154+ friend void to_xml (const input_tensor&, std::ostream& out)
1155+ {
1156+ out << " <input_tensor/>\n " ;
1157+ }
1158+
1159+ private:
1160+
1161+ template <typename forward_iterator>
1162+ long long count_samples (
1163+ forward_iterator ibegin,
1164+ forward_iterator iend
1165+ ) const
1166+ {
1167+ return std::accumulate (ibegin, iend, 0 ,
1168+ [](long long a, const auto & b) { return a + b.num_samples (); });
1169+ }
1170+ };
1171+
10851172// ----------------------------------------------------------------------------------------
10861173
10871174}
0 commit comments