Skip to content

Commit 51c7a35

Browse files
authored
Add input_tensor input type (#2951)
1 parent fa0e3ff commit 51c7a35

File tree

5 files changed

+187
-0
lines changed

5 files changed

+187
-0
lines changed

dlib/cuda/tensor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,14 @@ namespace dlib
680680

681681
// ----------------------------------------------------------------------------------------
682682

683+
inline void memcpy (
684+
alias_tensor_instance&& dest,
685+
const tensor& src
686+
)
687+
{
688+
memcpy(static_cast<tensor&>(dest), src);
689+
}
690+
683691
}
684692

685693
#endif // DLIB_DNn_TENSOR_H_

dlib/cuda/tensor_abstract.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,14 @@ namespace dlib
607607
);
608608
};
609609

610+
inline void memcpy (
611+
alias_tensor_instance&& dest,
612+
const tensor& src
613+
) { memcpy(static_cast<tensor&>(dest), src); }
614+
/*!
615+
A convenient overload for copying from src to dest when you have a temporary alias tensor.
616+
!*/
617+
610618
class alias_tensor_const_instance
611619
{
612620
/*!

dlib/dnn/input.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,93 @@ namespace dlib
10821082
float avg_blue;
10831083
};
10841084

1085+
// ----------------------------------------------------------------------------------------
1086+
1087+
class input_tensor
1088+
{
1089+
public:
1090+
typedef tensor input_type;
1091+
1092+
input_tensor() {}
1093+
input_tensor(const input_tensor&) {}
1094+
1095+
template<typename forward_iterator>
1096+
void to_tensor(
1097+
forward_iterator ibegin,
1098+
forward_iterator iend,
1099+
resizable_tensor& data
1100+
) const
1101+
{
1102+
DLIB_CASSERT(std::distance(ibegin, iend) > 0);
1103+
const auto k = ibegin->k();
1104+
const auto nr = ibegin->nr();
1105+
const auto nc = ibegin->nc();
1106+
// make sure all the input tensors have the same dimensions
1107+
for (auto i = ibegin; i != iend; ++i)
1108+
{
1109+
DLIB_CASSERT(i->k() == k && i->nr() == nr && i->nc() == nc,
1110+
"\t input_tensor::to_tensor()"
1111+
<< "\n\t All tensor objects given to to_tensor() must have the same dimensions."
1112+
<< "\n\t k: " << k
1113+
<< "\n\t nr: " << nr
1114+
<< "\n\t nc: " << nc
1115+
<< "\n\t i->k(): " << i->k()
1116+
<< "\n\t i->nr(): " << i->nr()
1117+
<< "\n\t i->nc(): " << i->nc()
1118+
);
1119+
}
1120+
1121+
const auto num_samples = count_samples(ibegin, iend);
1122+
// initialize data to the right size to contain the stuff in the iterator range.
1123+
data.set_size(num_samples, k, nr, nc);
1124+
1125+
const size_t stride = k * nr * nc;
1126+
size_t offset = 0;
1127+
for (auto i = ibegin; i != iend; ++i)
1128+
{
1129+
alias_tensor slice(i->num_samples(), k, nr, nc);
1130+
memcpy(slice(data, offset), *i);
1131+
offset += slice.num_samples() * stride;
1132+
}
1133+
}
1134+
1135+
friend void serialize(const input_tensor&, std::ostream& out)
1136+
{
1137+
serialize("input_tensor", out);
1138+
}
1139+
1140+
friend void deserialize(input_tensor&, std::istream& in)
1141+
{
1142+
std::string version;
1143+
deserialize(version, in);
1144+
if (version != "input_tensor")
1145+
throw serialization_error("Unexpected version found while deserializing dlib::input_tensor.");
1146+
}
1147+
1148+
friend std::ostream& operator<<(std::ostream& out, const input_tensor&)
1149+
{
1150+
out << "input_tensor";
1151+
return out;
1152+
}
1153+
1154+
friend void to_xml(const input_tensor&, std::ostream& out)
1155+
{
1156+
out << "<input_tensor/>\n";
1157+
}
1158+
1159+
private:
1160+
1161+
template<typename forward_iterator>
1162+
long long count_samples(
1163+
forward_iterator ibegin,
1164+
forward_iterator iend
1165+
) const
1166+
{
1167+
return std::accumulate(ibegin, iend, 0,
1168+
[](long long a, const auto& b) { return a + b.num_samples(); });
1169+
}
1170+
};
1171+
10851172
// ----------------------------------------------------------------------------------------
10861173

10871174
}

dlib/dnn/input_abstract.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,57 @@ namespace dlib
719719

720720
// ----------------------------------------------------------------------------------------
721721

722+
class input_tensor
723+
{
724+
/*!
725+
WHAT THIS OBJECT REPRESENTS
726+
This input layer works with dlib::tensor objects. It is very similar to
727+
the dlib::input layer except that it allows for concatenating data that
728+
already resides in GPU memory.
729+
!*/
730+
731+
public:
732+
typedef tensor input_type;
733+
734+
input_tensor(
735+
);
736+
/*!
737+
ensures
738+
- input_tensor objects are default constructable
739+
!*/
740+
741+
input_tensor(
742+
const input_tensor& item
743+
);
744+
/*!
745+
ensures
746+
- input_tensor objects are copy constructable
747+
!*/
748+
749+
template <typename forward_iterator>
750+
void to_tensor(
751+
forward_iterator ibegin,
752+
forward_iterator iend,
753+
resizable_tensor& data
754+
) const;
755+
/*!
756+
requires
757+
- [ibegin, iend) is an iterator range over input_type objects.
758+
- std::distance(ibegin,iend) > 0
759+
- The input range should contain tensor objects that all have the same
760+
dimensions.
761+
ensures
762+
- Copies the iterator range into #data. In particular, if the input tensors
763+
have R rows, C columns, and K channels then we will have:
764+
- #data.num_samples() == count_samples(ibegin,iend)
765+
- #data.nr() == R
766+
- #data.nc() == C
767+
- #data.k() == K
768+
This results in a tensor concatenation along the sample dimension.
769+
!*/
770+
};
771+
772+
// ----------------------------------------------------------------------------------------
722773
}
723774

724775
#endif // DLIB_DNn_INPUT_ABSTRACT_H_

dlib/test/dnn.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4276,6 +4276,38 @@ namespace
42764276
#endif
42774277
}
42784278

4279+
void test_input_tensor()
4280+
{
4281+
using namespace dlib::tt;
4282+
print_spinner();
4283+
tt::tensor_rand rnd;
4284+
std::vector<resizable_tensor> tensors(3);
4285+
4286+
for (auto& t : tensors) {
4287+
t.set_size(1, 3, 224, 224);
4288+
rnd.fill_gaussian(t);
4289+
}
4290+
4291+
resizable_tensor out;
4292+
input_tensor input_layer;
4293+
4294+
input_layer.to_tensor(tensors.begin(), tensors.end(), out);
4295+
4296+
DLIB_TEST(out.num_samples() == 3);
4297+
DLIB_TEST(out.k() == 3);
4298+
DLIB_TEST(out.nr() == 224);
4299+
DLIB_TEST(out.nc() == 224);
4300+
size_t stride = out.k() * out.nr() * out.nc();
4301+
size_t offset = 0;
4302+
int error = 0;
4303+
4304+
for (auto& t : tensors) {
4305+
error = memcmp(out.host() + offset, t.host(), sizeof(float) * t.size());
4306+
DLIB_TEST(error == 0);
4307+
offset += stride;
4308+
}
4309+
}
4310+
42794311
// ----------------------------------------------------------------------------------------
42804312

42814313
class dnn_tester : public tester
@@ -4386,6 +4418,7 @@ namespace
43864418
test_input_ouput_mappers();
43874419
test_fuse_layers();
43884420
test_reorg();
4421+
test_input_tensor();
43894422
}
43904423

43914424
void perform_test()

0 commit comments

Comments
 (0)