|
1 |
| -#include <torch/csrc/inductor/aoti_torch/c/shim.h> |
| 1 | +#include <libtorchaudio/utils.h> |
2 | 2 | #include <torch/csrc/stable/library.h>
|
3 |
| -#include <torch/csrc/stable/ops.h> |
4 |
| -#include <torch/csrc/stable/tensor.h> |
5 |
| -#include <torch/script.h> |
6 |
| -#include <torch/torch.h> |
7 |
| - |
8 |
| -using namespace std; |
9 | 3 |
|
10 | 4 | namespace torchaudio {
|
11 | 5 | namespace alignment {
|
12 | 6 | namespace cpu {
|
| 7 | + |
| 8 | +using torch::stable::Tensor; |
| 9 | +using torch::headeronly::ScalarType; |
| 10 | + |
13 | 11 | // Inspired by
|
14 | 12 | // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
|
15 |
| -template <typename scalar_t, at::ScalarType target_scalar_type> |
| 13 | +template <typename scalar_t, ScalarType target_scalar_type> |
16 | 14 | void forced_align_impl(
|
17 |
| - const torch::Tensor& logProbs, |
18 |
| - const torch::Tensor& targets, |
| 15 | + const Tensor& logProbs, |
| 16 | + const Tensor& targets, |
19 | 17 | const int64_t blank,
|
20 |
| - torch::Tensor& paths) { |
| 18 | + Tensor& paths) { |
21 | 19 | const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
|
22 | 20 | using target_t = typename std::
|
23 |
| - conditional<target_scalar_type == torch::kInt, int, int64_t>::type; |
| 21 | + conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type; |
24 | 22 | const auto batchIndex =
|
25 | 23 | 0; // TODO: support batch version and use the real batch index
|
26 | 24 | const auto T = logProbs.size(1);
|
@@ -138,73 +136,111 @@ void forced_align_impl(
|
138 | 136 | delete[] backPtr_a;
|
139 | 137 | }
|
140 | 138 |
|
141 |
| -std::tuple<torch::Tensor, torch::Tensor> compute( |
142 |
| - const torch::Tensor& logProbs, |
143 |
| - const torch::Tensor& targets, |
144 |
| - const torch::Tensor& inputLengths, |
145 |
| - const torch::Tensor& targetLengths, |
| 139 | +std::tuple<Tensor, Tensor> compute( |
| 140 | + const Tensor& logProbs, |
| 141 | + const Tensor& targets, |
| 142 | + const Tensor& inputLengths, |
| 143 | + const Tensor& targetLengths, |
146 | 144 | const int64_t blank) {
|
147 |
| - TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor"); |
148 |
| - TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor"); |
149 |
| - TORCH_CHECK( |
150 |
| - logProbs.device() == targets.device(), |
151 |
| - "log_probs and targets need to be on the same device"); |
152 |
| - TORCH_CHECK( |
153 |
| - logProbs.dtype() == torch::kFloat64 || |
154 |
| - logProbs.dtype() == torch::kFloat32 || |
155 |
| - logProbs.dtype() == torch::kFloat16, |
| 145 | + STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor"); |
| 146 | + STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor"); |
| 147 | + STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor"); |
| 148 | + STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor"); |
| 149 | + STD_TORCH_CHECK( |
| 150 | + logProbs.scalar_type() == ScalarType::Double || |
| 151 | + logProbs.scalar_type() == ScalarType::Float || |
| 152 | + logProbs.scalar_type() == ScalarType::Half, |
156 | 153 | "log_probs must be float64, float32 or float16 (half) type");
|
157 |
| - TORCH_CHECK( |
158 |
| - targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64, |
| 154 | + STD_TORCH_CHECK( |
| 155 | + targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long, |
159 | 156 | "targets must be int32 or int64 type");
|
160 |
| - TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); |
161 |
| - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); |
162 |
| - TORCH_CHECK( |
| 157 | + STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); |
| 158 | + STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); |
| 159 | + STD_TORCH_CHECK( |
163 | 160 | logProbs.dim() == 3,
|
164 | 161 | "log_probs must be 3-D (batch_size, input length, num classes)");
|
165 |
| - TORCH_CHECK( |
| 162 | + STD_TORCH_CHECK( |
166 | 163 | targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
|
167 |
| - TORCH_CHECK( |
| 164 | + STD_TORCH_CHECK( |
168 | 165 | inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
|
169 |
| - TORCH_CHECK( |
| 166 | + STD_TORCH_CHECK( |
170 | 167 | targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
|
171 |
| - TORCH_CHECK( |
| 168 | + STD_TORCH_CHECK( |
172 | 169 | logProbs.size(0) == 1,
|
173 | 170 | "The batch dimension for log_probs must be 1 at the current version.")
|
174 |
| - TORCH_CHECK( |
| 171 | + STD_TORCH_CHECK( |
175 | 172 | targets.size(0) == 1,
|
176 | 173 | "The batch dimension for targets must be 1 at the current version.")
|
177 |
| - TORCH_CHECK( |
| 174 | + STD_TORCH_CHECK( |
178 | 175 | blank >= 0 && blank < logProbs.size(-1),
|
179 | 176 | "blank must be within [0, num classes)");
|
180 | 177 |
|
181 |
| - TORCH_CHECK( |
182 |
| - logProbs.size(1) == at::max(inputLengths).item().toInt(), |
| 178 | + STD_TORCH_CHECK( |
| 179 | + logProbs.size(1) == torchaudio::util::max<int>(inputLengths), |
183 | 180 | "input length mismatch");
|
184 |
| - TORCH_CHECK( |
185 |
| - targets.size(1) == at::max(targetLengths).item().toInt(), |
| 181 | + STD_TORCH_CHECK( |
| 182 | + targets.size(1) == torchaudio::util::max<int>(targetLengths), |
186 | 183 | "target length mismatch");
|
187 | 184 |
|
188 | 185 | const auto B = logProbs.size(0);
|
189 | 186 | const auto T = logProbs.size(1);
|
190 |
| - auto paths = torch::zeros( |
191 |
| - {B, T}, |
192 |
| - torch::TensorOptions().device(targets.device()).dtype(targets.dtype())); |
193 |
| - AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
194 |
| - logProbs.scalar_type(), "forced_align_impl", [&] { |
195 |
| - if (targets.scalar_type() == torch::kInt64) { |
196 |
| - forced_align_impl<scalar_t, torch::kInt64>( |
197 |
| - logProbs, targets, blank, paths); |
198 |
| - } else { |
199 |
| - forced_align_impl<scalar_t, torch::kInt32>( |
200 |
| - logProbs, targets, blank, paths); |
201 |
| - } |
202 |
| - }); |
| 187 | + Tensor paths = torch::stable::new_empty(targets, {B, T}); |
| 188 | + torch::stable::zero_(paths); |
| 189 | + |
| 190 | + switch (logProbs.scalar_type()) { |
| 191 | + case ScalarType::Double: { |
| 192 | + if (targets.scalar_type() == ScalarType::Long) { |
| 193 | + forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths); |
| 194 | + } else if (targets.scalar_type() == ScalarType::Int) { |
| 195 | + forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths); |
| 196 | + } else { |
| 197 | + STD_TORCH_CHECK(false, "unreachable"); |
| 198 | + } |
| 199 | + break; |
| 200 | + } |
| 201 | + case ScalarType::Float: { |
| 202 | + if (targets.scalar_type() == ScalarType::Long) { |
| 203 | + forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths); |
| 204 | + } else if (targets.scalar_type() == ScalarType::Int) { |
| 205 | + forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths); |
| 206 | + } else { |
| 207 | + STD_TORCH_CHECK(false, "unreachable"); |
| 208 | + } |
| 209 | + break; |
| 210 | + } |
| 211 | + case ScalarType::Half: { |
| 212 | + if (targets.scalar_type() == ScalarType::Long) { |
| 213 | + forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths); |
| 214 | + } else if (targets.scalar_type() == ScalarType::Int) { |
| 215 | + forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths); |
| 216 | + } else { |
| 217 | + STD_TORCH_CHECK(false, "unreachable"); |
| 218 | + } |
| 219 | + break; |
| 220 | + } |
| 221 | + default: { |
| 222 | + STD_TORCH_CHECK(false, "unreachable"); |
| 223 | + } |
| 224 | + }; |
| 225 | + |
203 | 226 | return std::make_tuple(paths, logProbs);
|
204 | 227 | }
|
205 | 228 |
|
206 |
| -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { |
207 |
| - m.impl("forced_align", &compute); |
| 229 | +void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { |
| 230 | + STD_TORCH_CHECK(num_args == 5, "num_args must be 5"); |
| 231 | + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); |
| 232 | + std::tuple<Tensor, Tensor> res = compute( |
| 233 | + /*logProbs*/to<Tensor>(stack[0]), |
| 234 | + /*targets*/to<Tensor>(stack[1]), |
| 235 | + /*logit_lengths*/to<Tensor>(stack[2]), |
| 236 | + /*target_lengths*/to<Tensor>(stack[3]), |
| 237 | + /*blank*/float(to<int64_t>(stack[4]))); |
| 238 | + stack[0] = from(std::get<0>(res)); |
| 239 | + stack[1] = from(std::get<1>(res)); |
| 240 | +} |
| 241 | + |
| 242 | +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { |
| 243 | + m.impl("forced_align", &boxed_forced_align_cpu); |
208 | 244 | }
|
209 | 245 |
|
210 | 246 | } // namespace cpu
|
|
0 commit comments