diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp index 7988099eb1..749596f34a 100644 --- a/src/libtorchaudio/forced_align/cpu/compute.cpp +++ b/src/libtorchaudio/forced_align/cpu/compute.cpp @@ -1,26 +1,26 @@ -#include +#include +#include +#include #include -#include -#include -#include -#include - -using namespace std; namespace torchaudio { namespace alignment { namespace cpu { + +using torch::headeronly::ScalarType; +using torch::stable::Tensor; + // Inspired from // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp -template +template void forced_align_impl( - const torch::Tensor& logProbs, - const torch::Tensor& targets, + const Tensor& logProbs, + const Tensor& targets, const int64_t blank, - torch::Tensor& paths) { + Tensor& paths) { const scalar_t kNegInfinity = -std::numeric_limits::infinity(); using target_t = typename std:: - conditional::type; + conditional::type; const auto batchIndex = 0; // TODO: support batch version and use the real batch index const auto T = logProbs.size(1); @@ -36,17 +36,16 @@ void forced_align_impl( for (int i = 0; i < T * S; i++) { backPtr_a[i] = -1; } - - auto logProbs_a = logProbs.accessor(); - auto targets_a = targets.accessor(); - auto paths_a = paths.accessor(); + auto logProbs_a = torchaudio::stable::accessor(logProbs); + auto targets_a = torchaudio::stable::accessor(targets); + auto paths_a = torchaudio::stable::accessor(paths); auto R = 0; for (auto i = 1; i < L; i++) { if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) { ++R; } } - TORCH_CHECK( + STD_TORCH_CHECK( T >= L + R, "targets length is too long for CTC. Found log_probs length: ", T, @@ -138,73 +137,89 @@ void forced_align_impl( delete[] backPtr_a; } -std::tuple compute( - const torch::Tensor& logProbs, - const torch::Tensor& targets, - const torch::Tensor& inputLengths, - const torch::Tensor& targetLengths, +std::tuple compute( + const Tensor& logProbs, + const Tensor& targets, + const Tensor& inputLengths, + const Tensor& targetLengths, const int64_t blank) { - TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor"); - TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor"); - TORCH_CHECK( - logProbs.device() == targets.device(), - "log_probs and targets need to be on the same device"); - TORCH_CHECK( - logProbs.dtype() == torch::kFloat64 || - logProbs.dtype() == torch::kFloat32 || - logProbs.dtype() == torch::kFloat16, + STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor"); + STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor"); + STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor"); + STD_TORCH_CHECK( + targetLengths.is_cpu(), "target_lengths must be a CPU tensor"); + STD_TORCH_CHECK( + logProbs.scalar_type() == ScalarType::Double || + logProbs.scalar_type() == ScalarType::Float || + logProbs.scalar_type() == ScalarType::Half, "log_probs must be float64, float32 or float16 (half) type"); - TORCH_CHECK( - targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64, + STD_TORCH_CHECK( + targets.scalar_type() == ScalarType::Int || + targets.scalar_type() == ScalarType::Long, "targets must be int32 or int64 type"); - TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); + 
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + STD_TORCH_CHECK( logProbs.dim() == 3, "log_probs must be 3-D (batch_size, input length, num classes)"); - TORCH_CHECK( + STD_TORCH_CHECK( targets.dim() == 2, "targets must be 2-D (batch_size, target length,)"); - TORCH_CHECK( + STD_TORCH_CHECK( inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)"); - TORCH_CHECK( + STD_TORCH_CHECK( targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)"); - TORCH_CHECK( + STD_TORCH_CHECK( logProbs.size(0) == 1, "The batch dimension for log_probs must be 1 at the current version.") - TORCH_CHECK( + STD_TORCH_CHECK( targets.size(0) == 1, "The batch dimension for targets must be 1 at the current version.") - TORCH_CHECK( + STD_TORCH_CHECK( blank >= 0 && blank < logProbs.size(-1), "blank must be within [0, num classes)"); - TORCH_CHECK( - logProbs.size(1) == at::max(inputLengths).item().toInt(), + STD_TORCH_CHECK( + logProbs.size(1) == torchaudio::util::max(inputLengths), "input length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(targetLengths).item().toInt(), + STD_TORCH_CHECK( + targets.size(1) == torchaudio::util::max(targetLengths), "target length mismatch"); const auto B = logProbs.size(0); const auto T = logProbs.size(1); - auto paths = torch::zeros( - {B, T}, - torch::TensorOptions().device(targets.device()).dtype(targets.dtype())); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}); + + STABLE_DISPATCH_FLOATING_TYPES_AND_HALF( logProbs.scalar_type(), "forced_align_impl", [&] { - if (targets.scalar_type() == torch::kInt64) { - forced_align_impl( + if (targets.scalar_type() == ScalarType::Long) { + forced_align_impl( logProbs, targets, blank, paths); } else { - forced_align_impl( + forced_align_impl( logProbs, targets, blank, paths); } }); return std::make_tuple(paths, logProbs); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("forced_align", &compute); +void boxed_forced_align_cpu( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + STD_TORCH_CHECK(num_args == 5, "num_args must be 5"); + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); + std::tuple res = compute( + /*logProbs*/ to(stack[0]), + /*targets*/ to(stack[1]), + /*logit_lengths*/ to(stack[2]), + /*target_lengths*/ to(stack[3]), + /*blank*/ float(to(stack[4]))); + stack[0] = from(std::get<0>(res)); + stack[1] = from(std::get<1>(res)); +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("forced_align", &boxed_forced_align_cpu); } } // namespace cpu diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu index a78f694b51..7d7fb86bc5 100644 --- a/src/libtorchaudio/forced_align/gpu/compute.cu +++ b/src/libtorchaudio/forced_align/gpu/compute.cu @@ -1,11 +1,11 @@ -#include -#include -#include -#include -#include +#include +#include +#include +#include + #include +#include -using namespace torch::indexing; namespace { constexpr int kNumThreads = 1024; // Number of threads to run CUDA kernel in parallel. 
@@ -16,11 +16,15 @@ constexpr int kBackPtrBufferSize = namespace torchaudio { namespace alignment { namespace gpu { + +using torch::stable::Tensor; +using torch::headeronly::ScalarType; + template __global__ void falign_cuda_step_kernel( - const torch::PackedTensorAccessor32 + const torchaudio::stable::PackedTensorAccessor32 logProbs_a, - const torch::PackedTensorAccessor32 + const torchaudio::stable::PackedTensorAccessor32 targets_a, const int T, const int L, @@ -31,9 +35,9 @@ __global__ void falign_cuda_step_kernel( int start, int end, int backPtrBufferLen, - torch::PackedTensorAccessor32 + torchaudio::stable::PackedTensorAccessor32 alphas_a, - torch::PackedTensorAccessor32 + torchaudio::stable::PackedTensorAccessor32 backPtrBuffer_a) { scalar_t kNegInfinity = -std::numeric_limits::infinity(); const int batchIndex = @@ -109,52 +113,44 @@ __global__ void falign_cuda_step_kernel( } } -template +template void forced_align_impl( - const torch::Tensor& logProbs, - const torch::Tensor& targets, + const Tensor& logProbs, + const Tensor& targets, const int64_t blank, - torch::Tensor& paths) { + Tensor& paths) { auto defaultStream = at::cuda::getCurrentCUDAStream(); auto cpuDataTranferStream = at::cuda::getStreamFromPool(); const scalar_t kNegInfinity = -std::numeric_limits::infinity(); using target_t = typename std:: - conditional::type; - auto paths_a = paths.accessor(); + conditional::type; + auto paths_a = torchaudio::stable::accessor(paths); const int batchIndex = 0; // TODO: support batch version and use the real batch index const int T = logProbs.size(1); // num frames const int N = logProbs.size(2); // alphabet size const int L = targets.size(1); // label length const int S = 2 * L + 1; - auto targetsCpu = targets.to(torch::kCPU); + + auto targetsCpu = torchaudio::stable::cpu(targets); // backPtrBuffer stores the index offset fthe best path at current position // We copy the values to CPU after running every kBackPtrBufferSize of // frames. - torch::Tensor backPtrBuffer = - torch::empty( - {min(kBackPtrBufferSize, T), S}, - torch::TensorOptions().dtype(torch::kInt8).device(logProbs.device())) - .contiguous() - .fill_(-1); - torch::Tensor backPtrCpu = - torch::empty( - {T, S}, - torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU)) - .contiguous() - .fill_(-1); + Tensor backPtrBuffer = torch::stable::new_empty(logProbs, {min(kBackPtrBufferSize, T), S}, ScalarType::Char); + torch::stable::fill_(backPtrBuffer, -1); + + Tensor backPtrCpu = torch::stable::new_empty(targetsCpu, {T, S}, ScalarType::Char); + torch::stable::fill_(backPtrCpu, -1); + // we store only two time frames for alphas // alphas for compute current timeframe can be computed only from previous // time frame. - torch::Tensor alphas = torch::empty( - {2, S}, - torch::TensorOptions() - .dtype(logProbs.dtype()) - .device(logProbs.device())) - .fill_(kNegInfinity); + Tensor alphas = torch::stable::new_empty(logProbs, {2, S}); + torch::stable::fill_(alphas, kNegInfinity); + // CPU accessors - auto targetsCpu_a = targetsCpu.accessor(); - auto backPtrCpu_a = backPtrCpu.accessor(); + auto targetsCpu_a = torchaudio::stable::accessor(targetsCpu); + auto backPtrCpu_a = torchaudio::stable::accessor(backPtrCpu); // count the number of repeats in label int R = 0; for (int i = 1; i < L; ++i) { @@ -162,7 +158,7 @@ void forced_align_impl( ++R; } } - TORCH_CHECK( + STD_TORCH_CHECK( T >= L + R, "targets length is too long for CTC. Found log_probs length: ", T, @@ -173,7 +169,7 @@ void forced_align_impl( int start = (T - (L + R)) > 0 ? 
0 : 1; int end = (S == 1) ? 1 : 2; int backPtrBufferLen = 0; - torch::Tensor bufferCopy; + Tensor bufferCopy; for (int t = 0; t < T; ++t) { if (t > 0) { if (T - t <= L + R) { @@ -195,8 +191,8 @@ void forced_align_impl( } falign_cuda_step_kernel <<<1, kNumThreads, 0, defaultStream>>>( - logProbs.packed_accessor32(), - targets.packed_accessor32(), + torchaudio::stable::packed_accessor32(logProbs), + torchaudio::stable::packed_accessor32(targets), T, L, N, @@ -206,15 +202,15 @@ void forced_align_impl( start, end, backPtrBufferLen, - alphas.packed_accessor32(), - backPtrBuffer - .packed_accessor32()); + torchaudio::stable::packed_accessor32(alphas), + torchaudio::stable::packed_accessor32(backPtrBuffer)); C10_CUDA_KERNEL_LAUNCH_CHECK(); ++backPtrBufferLen; if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) { cpuDataTranferStream.synchronize(); // GPU -> GPU copy - bufferCopy = backPtrBuffer.clone().contiguous(); + bufferCopy = torchaudio::stable::clone(backPtrBuffer); + STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()") defaultStream.synchronize(); at::cuda::setCurrentCUDAStream(cpuDataTranferStream); // Copy ASYNC from GPU to CPU @@ -231,8 +227,9 @@ void forced_align_impl( } } cpuDataTranferStream.synchronize(); - torch::Tensor alphasCpu = alphas.to(torch::kCPU); - auto alphasCpu_a = alphasCpu.accessor(); + + auto alphasCpu = torchaudio::stable::cpu(alphas); + auto alphasCpu_a = torchaudio::stable::accessor(alphasCpu); int curIdxOffset = ((T - 1) % 2); int ltrIdx = alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2] @@ -246,75 +243,90 @@ void forced_align_impl( } } -std::tuple compute( - const torch::Tensor& logProbs, - const torch::Tensor& targets, - const torch::Tensor& inputLengths, - const torch::Tensor& targetLengths, +std::tuple compute( + const Tensor& logProbs, + const Tensor& targets, + const Tensor& inputLengths, + const Tensor& targetLengths, const int64_t blank) { - TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor"); - TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor"); - TORCH_CHECK( - logProbs.device() == targets.device(), + + STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor"); + STD_TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor"); + STD_TORCH_CHECK( + logProbs.get_device_index() == targets.get_device_index(), "log_probs and targets need to be on the same device"); - TORCH_CHECK( - logProbs.dtype() == torch::kFloat64 || - logProbs.dtype() == torch::kFloat32 || - logProbs.dtype() == torch::kFloat16, + STD_TORCH_CHECK(inputLengths.is_cuda(), "input_lengths must be a CUDA tensor"); + STD_TORCH_CHECK(targetLengths.is_cuda(), "target_lengths must be a CUDA tensor"); + STD_TORCH_CHECK( + logProbs.scalar_type() == ScalarType::Double || + logProbs.scalar_type() == ScalarType::Float || + logProbs.scalar_type() == ScalarType::Half, "log_probs must be float64, float32 or float16 (half) type"); - TORCH_CHECK( - targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64, + STD_TORCH_CHECK( + targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long, "targets must be int32 or int64 type"); - TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); + STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + 
STD_TORCH_CHECK( logProbs.dim() == 3, "log_probs must be 3-D (batch_size, input length, num classes)"); - TORCH_CHECK( + STD_TORCH_CHECK( targets.dim() == 2, "targets must be 2-D (batch_size, target length,)"); - TORCH_CHECK( + STD_TORCH_CHECK( inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)"); - TORCH_CHECK( + STD_TORCH_CHECK( targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)"); - TORCH_CHECK( + STD_TORCH_CHECK( logProbs.size(0) == 1, "The batch dimension for log_probs must be 1 at the current version.") - TORCH_CHECK( + STD_TORCH_CHECK( targets.size(0) == 1, "The batch dimension for targets must be 1 at the current version.") - TORCH_CHECK( + STD_TORCH_CHECK( blank >= 0 && blank < logProbs.size(-1), "blank must be within [0, num classes)"); - TORCH_CHECK( - logProbs.size(1) == at::max(inputLengths).item().toInt(), + STD_TORCH_CHECK(logProbs.size(1) == torchaudio::util::max(inputLengths), "input length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(targetLengths).item().toInt(), + STD_TORCH_CHECK( + targets.size(1) == torchaudio::util::max(targetLengths), "target length mismatch"); auto B = logProbs.size(0); auto T = logProbs.size(1); // num frames - auto paths = torch::zeros( - {B, T}, - torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype())); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + + Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torchaudio::stable::cpu_device()); + + STABLE_DISPATCH_FLOATING_TYPES_AND_HALF( logProbs.scalar_type(), "forced_align_impl", [&] { - if (targets.scalar_type() == torch::kInt64) { - forced_align_impl( + if (targets.scalar_type() == ScalarType::Long) { + forced_align_impl( logProbs, targets, blank, paths); } else { - forced_align_impl( + forced_align_impl( logProbs, targets, blank, paths); } }); - return std::make_tuple( - paths.to(logProbs.device()), - logProbs); + + Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index()); + return std::make_tuple(pathsCuda, logProbs); +} + +void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + STD_TORCH_CHECK(num_args == 5, "num_args must be 5"); + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); + std::tuple res = compute( + /*logProbs*/to(stack[0]), + /*targets*/to(stack[1]), + /*logit_lengths*/to(stack[2]), + /*target_lengths*/to(stack[3]), + /*blank*/float(to(stack[4]))); + stack[0] = from(std::get<0>(res)); + stack[1] = from(std::get<1>(res)); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("forced_align", &compute); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("forced_align", &boxed_forced_align_gpu); } } // namespace gpu diff --git a/src/libtorchaudio/stable/Device.h b/src/libtorchaudio/stable/Device.h new file mode 100644 index 0000000000..a46a0ccc29 --- /dev/null +++ b/src/libtorchaudio/stable/Device.h @@ -0,0 +1,45 @@ +#pragma once + +/* + This header files provides torchaudio::stable::Device struct that is + torch::stable::Tensor-compatible analogus of c10::Device defined + c10/core/Device.h. + + TODO: remove this header file when torch::stable provides all + features implemented here. +*/ + +#include + +namespace torchaudio::stable { + +using DeviceType = int32_t; +using torch::stable::accelerator::DeviceIndex; + +struct Device { + Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) { + // TODO: validate(); + } + + /// Returns the type of device this is. 
+  DeviceType type() const noexcept {
+    return type_;
+  }
+
+  /// Returns the optional index.
+  DeviceIndex index() const noexcept {
+    return index_;
+  }
+
+ private:
+  DeviceType type_;
+  DeviceIndex index_ = -1;
+};
+
+// A convenience function, not part of torch::stable
+inline Device cpu_device() {
+  Device d(aoti_torch_device_type_cpu(), 0);
+  return d;
+}
+
+} // namespace torchaudio::stable
diff --git a/src/libtorchaudio/stable/TensorAccessor.h b/src/libtorchaudio/stable/TensorAccessor.h
new file mode 100644
index 0000000000..83ba9a0e5c
--- /dev/null
+++ b/src/libtorchaudio/stable/TensorAccessor.h
@@ -0,0 +1,275 @@
+#pragma once
+/*
+  This header file provides torchaudio::stable::TensorAccessor
+  templates that are torch::stable::Tensor-compatible analogues of
+  at::TensorAccessor defined in ATen/core/TensorAccessor.h.
+
+  TODO: remove this header file when torch::stable provides all
+  features implemented here.
+*/
+
+// #include
+
+#include
+#include
+
+namespace torchaudio::stable {
+
+template <typename T>
+struct DefaultPtrTraits {
+  typedef T* PtrType;
+};
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <typename T>
+struct RestrictPtrTraits {
+  typedef T* __restrict__ PtrType;
+};
+#endif
+
+template <
+    typename T,
+    size_t N,
+    template <typename U> class PtrTraits = DefaultPtrTraits,
+    typename index_t = int64_t>
+class TensorAccessorBase {
+ public:
+  typedef typename PtrTraits<T>::PtrType PtrType;
+
+  C10_HOST_DEVICE TensorAccessorBase(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : data_(data_) /*, sizes_(sizes_), strides_(strides_)*/ {
+    // Originally, TensorAccessor is a view of sizes and strides as
+    // these are ArrayRef instances. Until torch::stable supports
+    // ArrayRef-like features, we store copies of sizes and strides:
+    for (auto i = 0; i < N; ++i) {
+      this->sizes_[i] = sizes_[i];
+      this->strides_[i] = strides_[i];
+    }
+  }
+
+  C10_HOST_DEVICE PtrType data() {
+    return data_;
+  }
+  C10_HOST_DEVICE const PtrType data() const {
+    return data_;
+  }
+
+ protected:
+  PtrType data_;
+  /*
+  const index_t* sizes_;
+  const index_t* strides_;
+  */
+  // NOLINTNEXTLINE(*c-arrays*)
+  index_t sizes_[N];
+  // NOLINTNEXTLINE(*c-arrays*)
+  index_t strides_[N];
+};
+
+template <
+    typename T,
+    size_t N,
+    template <typename U> class PtrTraits = DefaultPtrTraits,
+    typename index_t = int64_t>
+class TensorAccessor : public TensorAccessorBase<T, N, PtrTraits, index_t> {
+ public:
+  typedef typename PtrTraits<T>::PtrType PtrType;
+
+  C10_HOST_DEVICE TensorAccessor(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : TensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+  C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
+      index_t i) {
+    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
+        this->data_ + this->strides_[0] * i,
+        this->sizes_ + 1,
+        this->strides_ + 1);
+  }
+
+  C10_HOST_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
+      index_t i) const {
+    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
+        this->data_ + this->strides_[0] * i,
+        this->sizes_ + 1,
+        this->strides_ + 1);
+  }
+};
+
+template <typename T, template <typename U> class PtrTraits, typename index_t>
+class TensorAccessor<T, 1, PtrTraits, index_t>
+    : public TensorAccessorBase<T, 1, PtrTraits, index_t> {
+ public:
+  typedef typename PtrTraits<T>::PtrType PtrType;
+
+  C10_HOST_DEVICE TensorAccessor(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
+  C10_HOST_DEVICE T& operator[](index_t i) {
+    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
+    return this->data_[this->strides_[0] * i];
+  }
+  C10_HOST_DEVICE const T& operator[](index_t i) const {
+    return this->data_[this->strides_[0] * i];
+  }
+};
+
+template <
typename T, + size_t N, + template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +class GenericPackedTensorAccessorBase { + public: + typedef typename PtrTraits::PtrType PtrType; + C10_HOST GenericPackedTensorAccessorBase( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : data_(data_) { + std::copy(sizes_, sizes_ + N, std::begin(this->sizes_)); + std::copy(strides_, strides_ + N, std::begin(this->strides_)); + } + + template < + typename source_index_t, + class = std::enable_if_t>> + C10_HOST GenericPackedTensorAccessorBase( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : data_(data_) { + for (auto i = 0; i < N; ++i) { + this->sizes_[i] = sizes_[i]; + this->strides_[i] = strides_[i]; + } + } + + C10_HOST_DEVICE PtrType data() { + return data_; + } + C10_HOST_DEVICE const PtrType data() const { + return data_; + } + + protected: + PtrType data_; + // NOLINTNEXTLINE(*c-arrays*) + index_t sizes_[N]; + // NOLINTNEXTLINE(*c-arrays*) + index_t strides_[N]; + C10_HOST void bounds_check_(index_t i) const { + STD_TORCH_CHECK( + 0 <= i && i < index_t{N}, + "Index ", + i, + " is not within bounds of a tensor of dimension ", + N); + } +}; + +template < + typename T, + size_t N, + template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +class GenericPackedTensorAccessor + : public GenericPackedTensorAccessorBase { + public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : GenericPackedTensorAccessorBase( + data_, + sizes_, + strides_) {} + + // if index_t is not int64_t, we want to have an int64_t constructor + template < + typename source_index_t, + class = std::enable_if_t>> + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : GenericPackedTensorAccessorBase( + data_, + sizes_, + strides_) {} + + C10_DEVICE TensorAccessor operator[]( + index_t i) { + index_t* new_sizes = this->sizes_ + 1; + index_t* new_strides = this->strides_ + 1; + return TensorAccessor( + this->data_ + this->strides_[0] * i, new_sizes, new_strides); + } + + C10_DEVICE const TensorAccessor operator[]( + index_t i) const { + const index_t* new_sizes = this->sizes_ + 1; + const index_t* new_strides = this->strides_ + 1; + return TensorAccessor( + this->data_ + this->strides_[0] * i, new_sizes, new_strides); + } +}; + +template class PtrTraits, typename index_t> +class GenericPackedTensorAccessor + : public GenericPackedTensorAccessorBase { + public: + typedef typename PtrTraits::PtrType PtrType; + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : GenericPackedTensorAccessorBase( + data_, + sizes_, + strides_) {} + + template < + typename source_index_t, + class = std::enable_if_t>> + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : GenericPackedTensorAccessorBase( + data_, + sizes_, + strides_) {} + + C10_DEVICE T& operator[](index_t i) { + return this->data_[this->strides_[0] * i]; + } + C10_DEVICE const T& operator[](index_t i) const { + return this->data_[this->strides_[0] * i]; + } +}; + +template < + typename T, + size_t N, + template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor32 = + GenericPackedTensorAccessor; + +template < + typename T, + size_t N, + template class PtrTraits = 
DefaultPtrTraits> +using PackedTensorAccessor64 = + GenericPackedTensorAccessor; + +} // namespace torchaudio::stable diff --git a/src/libtorchaudio/stable/dispatch.h b/src/libtorchaudio/stable/dispatch.h new file mode 100644 index 0000000000..3ec96fbe46 --- /dev/null +++ b/src/libtorchaudio/stable/dispatch.h @@ -0,0 +1,91 @@ +#pragma once +/* + This header files provides CPP macros + + STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) + + that are torch::stable::Tensor-compatible analogous of + the following macros: + + AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) + + respectively. + + TODO: remove this header file when torch::stable provides all + features implemented here. +*/ + +#include +#include + +namespace torchaudio::stable { + +using torch::headeronly::ScalarType; + +namespace impl { + +inline const char* toString(ScalarType t) { +#define DEFINE_CASE(_, name) \ + case ScalarType::name: \ + return #name; + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + default: + return "UNKNOWN_SCALAR"; + } +#undef DEFINE_CASE +} + +template +struct ScalarTypeToCPPType; + +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ + }; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) + +#undef SPECIALIZE_ScalarTypeToCPPType + +template +using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType::type; + +} // namespace impl + +} // namespace torchaudio::stable + +#define STABLE_DISPATCH_CASE(enum_type, ...) \ + case enum_type: { \ + using scalar_t [[maybe_unused]] = \ + torchaudio::stable::impl::ScalarTypeToCPPTypeT; \ + return __VA_ARGS__(); \ + } + +#define STABLE_DISPATCH_SWITCH(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + constexpr const char* at_dispatch_name = NAME; \ + switch (the_type) { \ + __VA_ARGS__ \ + default: \ + STD_TORCH_CHECK( \ + false, \ + '"', \ + at_dispatch_name, \ + "\" not implemented for '", \ + torchaudio::stable::impl::toString(the_type), \ + "'"); \ + } \ + }() + +#define STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \ + STABLE_DISPATCH_CASE(ScalarType::Double, __VA_ARGS__) \ + STABLE_DISPATCH_CASE(ScalarType::Float, __VA_ARGS__) \ + STABLE_DISPATCH_CASE(ScalarType::Half, __VA_ARGS__) + +#define STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + STABLE_DISPATCH_SWITCH( \ + TYPE, NAME, STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__)) diff --git a/src/libtorchaudio/stable/ops.h b/src/libtorchaudio/stable/ops.h new file mode 100644 index 0000000000..d09aea14fa --- /dev/null +++ b/src/libtorchaudio/stable/ops.h @@ -0,0 +1,298 @@ +#pragma once + +/* + This header files provides torchaudio::stable operations that are + torch::stable::Tensor-compatible analogus operations defined in + ATen/core/TensorBase.h and elsewhere. + + TODO: remove this header file when torch::stable provides all + features implemented here. +*/ + +#include +#include +#include + +#ifdef USE_CUDA +#include +#include +#endif + +using torch::stable::Tensor; + +namespace torchaudio::stable { + +using Layout = int32_t; + +// TODO: When sizes and strides are implemented in torch::stable, +// eliminate sizes and strides function below. 
+inline std::vector sizes(const Tensor& t) { + int64_t* ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(t.get(), &ptr)); + std::vector r(ptr, ptr + t.dim()); + return r; +} + +inline std::vector strides(const Tensor& t) { + int64_t* ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(t.get(), &ptr)); + std::vector r(ptr, ptr + t.dim()); + return r; +} + +// TODO: When https://github.com/pytorch/pytorch/pull/161891 lands, +// eliminate mutable_data_ptr and const_data_ptr templates. +#define aoti_torch_get_mutable_data_ptr aoti_torch_get_data_ptr +#define aoti_torch_get_const_data_ptr aoti_torch_get_data_ptr +template +T* mutable_data_ptr(const Tensor& t) { + void* data_ptr{}; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_mutable_data_ptr(t.get(), &data_ptr)); + return reinterpret_cast(data_ptr); +} + +template +const T* const_data_ptr(const Tensor& t) { + const void* data_ptr{}; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_const_data_ptr(t.get(), const_cast(&data_ptr))); + return reinterpret_cast(data_ptr); +} + +// TODO: When accessor is implemented in torch::stable, eliminate +// accessor template below. + +template +torchaudio::stable::TensorAccessor accessor(const Tensor& t) { + static_assert( + N > 0, + "accessor is used for indexing tensor, for scalars use *data_ptr()"); + STD_TORCH_CHECK( + t.dim() == N, + "TensorAccessor expected ", + N, + " dims but tensor has ", + t.dim()); + T* ptr = nullptr; + if constexpr (std::is_const_v) { + ptr = const_data_ptr(t); + } else { + ptr = mutable_data_ptr(t); + } + auto sizes_ = sizes(t); + auto strides_ = strides(t); + return torchaudio::stable::TensorAccessor( + ptr, sizes_.data(), strides_.data()); +} + +// TODO: move to TensorAccessor.h? +template < + typename T, + size_t N, + template + class PtrTraits = torchaudio::stable::DefaultPtrTraits, + typename index_t = int64_t> +torchaudio::stable::GenericPackedTensorAccessor +generic_packed_accessor(const Tensor& t) { + static_assert( + N > 0, + "accessor is used for indexing tensor, for scalars use *data_ptr()"); + STD_TORCH_CHECK( + t.dim() == N, + "TensorAccessor expected ", + N, + " dims but tensor has ", + t.dim()); + T* ptr = nullptr; + if constexpr (std::is_const_v) { + ptr = const_data_ptr(t); + } else { + ptr = mutable_data_ptr(t); + } + auto sizes_ = sizes(t); + auto strides_ = strides(t); + return torchaudio::stable:: + GenericPackedTensorAccessor( + static_cast::PtrType>(ptr), + sizes_.data(), + strides_.data()); +} +// template class PtrTraits = +// torchaudio::stable::DefaultPtrTraits, typename index_t = int64_t> +// GenericPackedTensorAccessor generic_packed_accessor(const Tensor& t) && +// = delete; + +template < + typename T, + size_t N, + template + class PtrTraits = torchaudio::stable::DefaultPtrTraits> +torchaudio::stable::PackedTensorAccessor32 packed_accessor32( + const Tensor& t) { + STD_TORCH_CHECK( + t.numel() <= static_cast(std::numeric_limits::max()), + "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); + return generic_packed_accessor(t); +} +// template class PtrTraits = +// torchaudio::stable::DefaultPtrTraits> PackedTensorAccessor32 +// packed_accessor32(const Tensor& t) && = delete; + +template < + typename T, + size_t N, + template + class PtrTraits = torchaudio::stable::DefaultPtrTraits> +torchaudio::stable::PackedTensorAccessor64 packed_accessor64( + const Tensor& t) { + return generic_packed_accessor(); +} +// template class PtrTraits = +// DefaultPtrTraits> PackedTensorAccessor64 +// packed_accessor64(const Tensor& t) && = 
delete; + +// TODO: When https://github.com/pytorch/pytorch/pull/161895 lands, eliminate +// copy_ function below. +inline Tensor copy_( + Tensor& self, + const Tensor& src, + std::optional non_blocking = std::nullopt) { + const auto num_args = 3; + std::array stack{ + from(self), from(src), from(non_blocking.value_or(false))}; + TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::copy_", "", stack.data())); + return to(stack[0]); +} + +// TODO: When cpu is implemented in torch::stable, eliminate +// cpu function below. +inline Tensor cpu(const Tensor& self) { + auto sizes_ = sizes(self); + auto cpu_type = aoti_torch_device_type_cpu(); + int32_t dtype; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype)); + int32_t layout; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout)); + AtenTensorHandle ret0; + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty( + self.get(), + sizes_.data(), + static_cast(self.dim()), + &dtype, + &layout, + &cpu_type, + 0, + nullptr, // pin_memory (nullptr for default) + &ret0)); + auto result = Tensor(ret0); + copy_(result, self); + return result; +} + +// TODO: +inline Tensor cuda(const Tensor& self, int32_t cuda_index) { + auto sizes_ = sizes(self); + auto cuda_type = aoti_torch_device_type_cuda(); + int32_t dtype; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype)); + int32_t layout; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout)); + AtenTensorHandle ret0; + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty( + self.get(), + sizes_.data(), + static_cast(self.dim()), + &dtype, + &layout, + &cuda_type, + cuda_index, + nullptr, // pin_memory (nullptr for default) + &ret0)); + auto result = Tensor(ret0); + copy_(result, self); + return result; +} + +// TODO: remove when torch::stable provides new_zeros +inline Tensor new_zeros( + const Tensor& self, + std::vector size, + std::optional dtype = std::nullopt, + std::optional layout = std::nullopt, + std::optional device = std::nullopt, + std::optional pin_memory = std::nullopt) { + int32_t target_dtype{}; + if (dtype.has_value()) { + target_dtype = to(from(dtype.value())); + } else { + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype)); + } + + Layout layout_; + if (layout.has_value()) { + layout_ = layout.value(); + } else { + TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout_)); + } + + DeviceType device_type; + DeviceIndex device_index = 0; + if (device.has_value()) { + auto device_ = device.value(); + device_type = device_.type(); + device_index = device_.index(); + } else { + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_type(self.get(), &device_type)); + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(self.get(), &device_index)); + } + + // TODO: pin_memory + + AtenTensorHandle ret0; + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty( + self.get(), + size.data(), + static_cast(size.size()), + &target_dtype, + &layout_, + &device_type, + device_index, + nullptr, // pin_memory (nullptr for default) + &ret0)); + + auto result = Tensor(ret0); + torch::stable::zero_(result); + return result; +} + +// TODO: https://github.com/pytorch/pytorch/pull/161896 +inline Tensor clone(const Tensor& self) { + AtenTensorHandle ret = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_clone(self.get(), &ret)); + return Tensor(ret); +} + +// An analog of item template function defined in +// ATen/templates/TensorBody.h +template +T item(const Tensor& self) { + STD_TORCH_CHECK( + self.numel() == 1, "item requires single element 
tensor input");
+  if (self.is_cpu()) {
+    return torchaudio::stable::const_data_ptr<T>(self)[0];
+#ifdef USE_CUDA
+  } else if (self.is_cuda()) {
+    T value;
+    C10_CUDA_CHECK(cudaMemcpyAsync(
+        &value, self.data_ptr(), sizeof(T), cudaMemcpyDeviceToHost));
+    return value;
+#endif
+  } else {
+    STD_TORCH_CHECK(false, "unreachable"); // not implemented
+  }
+}
+
+} // namespace torchaudio::stable
diff --git a/src/libtorchaudio/utils.cpp b/src/libtorchaudio/utils.cpp
index 4789f0c0a8..a6000756ca 100644
--- a/src/libtorchaudio/utils.cpp
+++ b/src/libtorchaudio/utils.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 
 #ifdef USE_CUDA
 #include
diff --git a/src/libtorchaudio/utils.h b/src/libtorchaudio/utils.h
index ffab65cd38..dd5ca6de95 100644
--- a/src/libtorchaudio/utils.h
+++ b/src/libtorchaudio/utils.h
@@ -1,8 +1,21 @@
 #pragma once
-
-#include
+
+// TODO: replace the include libtorchaudio/stable/ops.h with
+// torch/stable/ops.h when torch::stable provides all required
+// features (torch::stable::item or similar):
+#include <libtorchaudio/stable/ops.h>
 
 namespace torchaudio {
+
+namespace util {
+template <typename T>
+T max(const torch::stable::Tensor& t) {
+  return torchaudio::stable::item<T>(torch::stable::amax(t, {}));
+}
+} // namespace util
+
 bool is_rir_available();
 bool is_align_available();
 std::optional<int64_t> cuda_version();
+
 } // namespace torchaudio
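
Not part of the patch above: a minimal usage sketch of the stable-ABI helpers this diff introduces (STABLE_DISPATCH_FLOATING_TYPES_AND_HALF, torchaudio::stable::accessor, STD_TORCH_CHECK), assuming the headers land under libtorchaudio/stable/ as shown. The function name scale_rows and its scaling logic are hypothetical; they only illustrate the check / dispatch / accessor sequence that forced_align_impl follows, just with a 2-D tensor instead of a 3-D log-probability tensor.

// Hypothetical illustration, not part of this patch: scale a 2-D
// floating-point stable Tensor in place with the helpers added above.
#include <libtorchaudio/stable/dispatch.h>
#include <libtorchaudio/stable/ops.h>

using torch::headeronly::ScalarType;
using torch::stable::Tensor;

void scale_rows(Tensor& x, double factor) {
  STD_TORCH_CHECK(x.dim() == 2, "x must be 2-D");
  STD_TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
  STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "scale_rows", [&] {
        // accessor<T, N> mirrors at::Tensor::accessor<T, N>(), but is a free
        // function taking the stable Tensor instead of a member function.
        auto x_a = torchaudio::stable::accessor<scalar_t, 2>(x);
        for (int64_t i = 0; i < x.size(0); ++i) {
          for (int64_t j = 0; j < x.size(1); ++j) {
            x_a[i][j] = static_cast<scalar_t>(x_a[i][j] * factor);
          }
        }
      });
}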