|
1 | 1 | #include "ctranslate2/ops/median_filter.h" |
2 | 2 |
|
3 | | -#include <algorithm> |
4 | | - |
5 | | -#include "cpu/parallel.h" |
| 3 | +#include "dispatch.h" |
6 | 4 |
|
7 | 5 | namespace ctranslate2 { |
8 | 6 | namespace ops { |
9 | 7 |
|
10 | | - MedianFilter::MedianFilter(const dim_t width) |
| 8 | + MedianFilter::MedianFilter(dim_t width) |
11 | 9 | : _width(width) |
12 | | - { |
13 | | - } |
| 10 | + { |
| 11 | + } |
14 | 12 |
|
15 | 13 | void MedianFilter::operator()(const StorageView& input, StorageView& output) const { |
16 | 14 | PROFILE("MedianFilter"); |
17 | 15 |
|
18 | | - if (input.device() != Device::CPU) |
19 | | - throw std::invalid_argument("MedianFilter currently only supports CPU execution"); |
| 16 | + const dim_t axis = input.rank() - 1; |
| 17 | + const dim_t axis_size = input.dim(axis); |
20 | 18 |
|
21 | 19 | output.resize_as(input); |
22 | 20 |
|
23 | | - const dim_t depth = input.dim(-1); |
24 | | - const dim_t batch_size = input.size() / depth; |
25 | | - const dim_t rank = _width / 2; |
26 | | - |
27 | | - if (depth <= rank) |
28 | | - return; |
29 | | - |
30 | | - const auto* src = input.data<float>(); |
31 | | - auto* dst = output.data<float>(); |
32 | | - |
33 | | - cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) { |
34 | | - StorageView window_storage({_width}, DataType::FLOAT32); |
35 | | - auto* window = window_storage.data<float>(); |
36 | | - |
37 | | - for (dim_t i = begin; i < end; ++i) { |
38 | | - const dim_t offset = i * depth; |
39 | | - const auto* in = src + offset; |
40 | | - auto* out = dst + offset; |
41 | | - |
42 | | - for (dim_t j = 0; j < depth; ++j) { |
43 | | - for (dim_t k = -rank; k <= rank; ++k) { |
44 | | - dim_t read = std::abs(j + k); |
45 | | - if (read >= depth) |
46 | | - read = depth - (read - depth) - 2; |
47 | | - window[k + rank] = in[read]; |
48 | | - } |
49 | | - |
50 | | - std::nth_element(window, window + rank, window + _width); |
51 | | - out[j] = window[rank]; |
52 | | - } |
53 | | - } |
54 | | - }); |
| 21 | + DEVICE_AND_FLOAT_DISPATCH("MedianFilter", input.device(), input.dtype(), |
| 22 | + (compute<D, T>(input, axis_size, output))); |
55 | 23 | } |
56 | 24 |
|
57 | 25 | } |
|
0 commit comments