|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include <algorithm> |
| 16 | +#include "paddle/framework/op_registry.h" |
| 17 | +#include "paddle/platform/cuda_helper.h" |
| 18 | +#include "paddle/platform/gpu_info.h" |
| 19 | + |
| 20 | +namespace paddle { |
| 21 | +namespace operators { |
| 22 | + |
| 23 | +using platform::PADDLE_CUDA_NUM_THREADS; |
| 24 | + |
| 25 | +template <typename T> |
| 26 | +__global__ void FillFirstRow(T* dist, const int N) { |
| 27 | + int idx = blockDim.x * blockIdx.x + threadIdx.x; |
| 28 | + if (idx < N + 1) { |
| 29 | + dist[idx] = idx; |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +template <typename T> |
| 34 | +__global__ void FillFirstColumn(T* dist, const int M, const int N) { |
| 35 | + int idx = blockDim.x * blockIdx.x + threadIdx.x; |
| 36 | + if (idx < M + 1) { |
| 37 | + dist[idx * (N + 1)] = idx; |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | +template <typename T> |
| 42 | +__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, |
| 43 | + const int N, const int start) { |
| 44 | + int idx = blockDim.x * blockIdx.x + threadIdx.x; |
| 45 | + int offset = N; |
| 46 | + int index = start + idx * offset; |
| 47 | + int row = index / (N + 1); |
| 48 | + int col = index % (N + 1); |
| 49 | + if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { |
| 50 | + int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; |
| 51 | + int dels = dist[(row - 1) * (N + 1) + col] + 1; |
| 52 | + int ins = dist[row * (N + 1) + col - 1] + 1; |
| 53 | + int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; |
| 54 | + dist[index] = min(dels, min(ins, subs)); |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +template <typename T> |
| 59 | +__global__ void SetOutput(T* out, const T* dist, const int M, const int N, |
| 60 | + bool normalized) { |
| 61 | + int idx = blockDim.x * blockIdx.x + threadIdx.x; |
| 62 | + if (idx == 0) { |
| 63 | + out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +template <typename Place, typename T> |
| 68 | +class EditDistanceGPUKernel : public framework::OpKernel<T> { |
| 69 | + public: |
| 70 | + void Compute(const framework::ExecutionContext& ctx) const { |
| 71 | + auto* out_t = ctx.Output<framework::Tensor>("Out"); |
| 72 | + |
| 73 | + auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps"); |
| 74 | + auto* x2_t = ctx.Input<framework::LoDTensor>("Refs"); |
| 75 | + |
| 76 | + auto normalized = ctx.Attr<bool>("normalized"); |
| 77 | + auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( |
| 78 | + ctx.device_context()) |
| 79 | + .stream(); |
| 80 | + |
| 81 | + auto hyp_lod = x1_t->lod()[0]; |
| 82 | + auto ref_lod = x2_t->lod()[0]; |
| 83 | + PADDLE_ENFORCE( |
| 84 | + hyp_lod.size() == ref_lod.size(), |
| 85 | + "Input(Hyps) and Input(Refs) must have the same batch size."); |
| 86 | + for (size_t i = 1; i < ref_lod.size(); ++i) { |
| 87 | + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], |
| 88 | + "Reference string %d is empty.", i); |
| 89 | + } |
| 90 | + |
| 91 | + auto num_strs = hyp_lod.size() - 1; |
| 92 | + out_t->Resize({static_cast<int64_t>(num_strs), 1}); |
| 93 | + out_t->mutable_data<T>(ctx.GetPlace()); |
| 94 | + auto out = out_t->data<T>(); |
| 95 | + |
| 96 | + T distance = 0.0; |
| 97 | + for (size_t num = 0; num < num_strs; num++) { |
| 98 | + auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]); |
| 99 | + auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]); |
| 100 | + if (m == 0 || n == 0) { |
| 101 | + distance = std::max(m, n); |
| 102 | + if (normalized) { |
| 103 | + PADDLE_ENFORCE(n > 0, |
| 104 | + "The reference string (#%d) cannot be empty " |
| 105 | + "when Attr(normalized) is enabled.", |
| 106 | + n); |
| 107 | + distance = distance / n; |
| 108 | + } |
| 109 | + memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num, |
| 110 | + platform::CPUPlace(), &distance, sizeof(T), stream); |
| 111 | + } else { |
| 112 | + framework::Tensor dist_t; |
| 113 | + dist_t.Resize({m + 1, n + 1}); |
| 114 | + dist_t.mutable_data<T>(ctx.GetPlace()); |
| 115 | + auto dist = dist_t.data<T>(); |
| 116 | + auto x1 = x1_t->data<int>() + hyp_lod[num]; |
| 117 | + auto x2 = x2_t->data<int>() + ref_lod[num]; |
| 118 | + |
| 119 | + FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS, |
| 120 | + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); |
| 121 | + |
| 122 | + FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS, |
| 123 | + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); |
| 124 | + // Compute the elements of distance matrix in the anti-diagonal diretion |
| 125 | + for (int64_t slice = 2; slice < m + n + 1; ++slice) { |
| 126 | + int z_m = slice < m + 1 ? 0 : slice - m; |
| 127 | + int z_n = slice < n + 1 ? 0 : slice - n; |
| 128 | + int size = slice - (z_m + z_n) + 1; // number of elments in the same |
| 129 | + // anti-diagonal line to update |
| 130 | + // the start index at which computes from |
| 131 | + int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; |
| 132 | + Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, |
| 133 | + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, |
| 134 | + m, n, start); |
| 135 | + } |
| 136 | + SetOutput<T><<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized); |
| 137 | + } |
| 138 | + } |
| 139 | + } |
| 140 | +}; |
| 141 | + |
| 142 | +} // namespace operators |
| 143 | +} // namespace paddle |
| 144 | + |
| 145 | +namespace ops = paddle::operators; |
| 146 | + |
| 147 | +REGISTER_OP_CUDA_KERNEL( |
| 148 | + edit_distance, |
| 149 | + ops::EditDistanceGPUKernel<paddle::platform::CUDAPlace, float>); |
0 commit comments