Skip to content

Commit 08e7dc0

Browse files
authored
Update median_filter.cc
1 parent 9cd7f17 commit 08e7dc0

File tree

1 file changed

+8 -40 lines changed

1 file changed

+8 -40 lines changed

src/ops/median_filter.cc

Lines changed: 8 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,57 +1,25 @@
11
#include "ctranslate2/ops/median_filter.h"
22

3-
#include <algorithm>
4-
5-
#include "cpu/parallel.h"
3+
#include "dispatch.h"
64

75
namespace ctranslate2 {
86
namespace ops {
97

10-
MedianFilter::MedianFilter(const dim_t width)
8+
MedianFilter::MedianFilter(dim_t width)
119
: _width(width)
12-
{
13-
}
10+
{
11+
}
1412

1513
void MedianFilter::operator()(const StorageView& input, StorageView& output) const {
1614
PROFILE("MedianFilter");
1715

18-
if (input.device() != Device::CPU)
19-
throw std::invalid_argument("MedianFilter currently only supports CPU execution");
16+
const dim_t axis = input.rank() - 1;
17+
const dim_t axis_size = input.dim(axis);
2018

2119
output.resize_as(input);
2220

23-
const dim_t depth = input.dim(-1);
24-
const dim_t batch_size = input.size() / depth;
25-
const dim_t rank = _width / 2;
26-
27-
if (depth <= rank)
28-
return;
29-
30-
const auto* src = input.data<float>();
31-
auto* dst = output.data<float>();
32-
33-
cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
34-
StorageView window_storage({_width}, DataType::FLOAT32);
35-
auto* window = window_storage.data<float>();
36-
37-
for (dim_t i = begin; i < end; ++i) {
38-
const dim_t offset = i * depth;
39-
const auto* in = src + offset;
40-
auto* out = dst + offset;
41-
42-
for (dim_t j = 0; j < depth; ++j) {
43-
for (dim_t k = -rank; k <= rank; ++k) {
44-
dim_t read = std::abs(j + k);
45-
if (read >= depth)
46-
read = depth - (read - depth) - 2;
47-
window[k + rank] = in[read];
48-
}
49-
50-
std::nth_element(window, window + rank, window + _width);
51-
out[j] = window[rank];
52-
}
53-
}
54-
});
21+
DEVICE_AND_FLOAT_DISPATCH("MedianFilter", input.device(), input.dtype(),
22+
(compute<D, T>(input, axis_size, output)));
5523
}
5624

5725
}

0 commit comments

Comments
 (0)