Skip to content

Commit d49442e

Browse files
committed
add validation method for conv
1 parent 152b982 commit d49442e

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

src/conv1d.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,43 @@ sum_and_normalize_conv(sycl::queue &Q, span3d_t data, size_t nw) {
3535
return sum;
3636
}
3737

38+
void
39+
validate_conv1d(sycl::queue &Q, span3d_t data, size_t nw) {
40+
const auto n0 = data.extent(0);
41+
const auto n2 = data.extent(2);
42+
43+
sycl::range<1> range0(n0);
44+
sycl::range<1> range2(n2);
45+
46+
Q.parallel_for(range0, [=](unsigned i0) {
47+
for (auto i1 = 0; i1 < nw; ++i1) {
48+
for (auto i2 = 0; i2 < n2 - 1; ++i2) {
49+
if(data(i0, i1, i2) != data(i0, i1, i2+1)){
50+
// throw std::runtime_error("nn");
51+
data(0,0,0) = -1000;
52+
};
53+
}
54+
}
55+
});
56+
57+
Q.parallel_for(range2, [=](unsigned i2) {
58+
for (auto i1 = 0; i1 < nw; ++i1) {
59+
for (auto i0 = 0; i0 < n0 - 1; ++i0) {
60+
if(data(i0, i1, i2) != data(i0+1, i1, i2)){
61+
// throw std::runtime_error("nn");
62+
data(0,0,0) = -2000;
63+
};
64+
}
65+
}
66+
});
67+
68+
Q.wait();
69+
if (data(0,0,0) == -1000 || data(0,0,0) == -2000)
70+
std::cout << "Values at same position i1 are not equivalent throught "
71+
"the batchs"
72+
<< std::endl;
73+
}
74+
3875

3976

4077
// ==========================================
@@ -70,10 +107,10 @@ main(int argc, char **argv) {
70107
const auto n2 = params.n2; // n
71108
const auto k = params.k;
72109

73-
span3d_t data(sycl::malloc_device<real_t>(n0 * n1 * n2, Q), n0,
110+
span3d_t data(sycl::malloc_shared<real_t>(n0 * n1 * n2, Q), n0,
74111
n1, n2);
75112
span3d_t warmup_data(
76-
sycl::malloc_device<real_t>(n0 * n1 * n2, Q), n0, n1, n2);
113+
sycl::malloc_shared<real_t>(n0 * n1 * n2, Q), n0, n1, n2);
77114
Q.wait();
78115

79116
Q.parallel_for(sycl::range<3>(n0, n1, n2), [=](auto itm) {
@@ -86,8 +123,8 @@ main(int argc, char **argv) {
86123
}).wait();
87124

88125
real_t *d_weight =
89-
sycl::malloc_device<real_t>(k * channel_out * channel_in, Q);
90-
real_t *d_bias = sycl::malloc_device<real_t>(channel_out, Q);
126+
sycl::malloc_shared<real_t>(k * channel_out * channel_in, Q);
127+
real_t *d_bias = sycl::malloc_shared<real_t>(channel_out, Q);
91128
Q.wait();
92129
Q.parallel_for(sycl::range<1>(k * channel_out * channel_in), [=](auto itm) {
93130
d_weight[itm] = 1.5;
@@ -136,6 +173,8 @@ main(int argc, char **argv) {
136173
std::cout << "Normalized Array after: " << err << std::endl;
137174
std::cout << std::endl;
138175

176+
validate_conv1d(Q, data, params.n_write);
177+
139178
//==========================================================================
140179
//==========================================================================
141180
std::cout << "PERF_DIAGS:" << std::endl;

0 commit comments

Comments
 (0)