diff --git a/src/libtorchaudio/forced_align/compute.cpp b/src/libtorchaudio/forced_align/compute.cpp
index 40cbcc8dd7..b1e90603e5 100644
--- a/src/libtorchaudio/forced_align/compute.cpp
+++ b/src/libtorchaudio/forced_align/compute.cpp
@@ -1,19 +1,10 @@
-#include <libtorchaudio/forced_align/compute.h>
-#include <torch/script.h>
+#include <torch/csrc/stable/library.h>
 
-std::tuple<torch::Tensor, torch::Tensor> forced_align(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const torch::Tensor& inputLengths,
-    const torch::Tensor& targetLengths,
-    const int64_t blank) {
-  static auto op = torch::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchaudio::forced_align", "")
-                       .typed<decltype(forced_align)>();
-  return op.call(logProbs, targets, inputLengths, targetLengths, blank);
-}
-
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
-      "forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)");
+      "forced_align(Tensor log_probs,"
+      "Tensor targets,"
+      "Tensor input_lengths,"
+      "Tensor target_lengths,"
+      "int blank) -> (Tensor, Tensor)");
 }
diff --git a/src/libtorchaudio/forced_align/compute.h b/src/libtorchaudio/forced_align/compute.h
index aaae9b086c..6f70f09bee 100644
--- a/src/libtorchaudio/forced_align/compute.h
+++ b/src/libtorchaudio/forced_align/compute.h
@@ -1,10 +1 @@
 #pragma once
-
-#include <torch/script.h>
-
-std::tuple<torch::Tensor, torch::Tensor> forced_align(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const torch::Tensor& inputLengths,
-    const torch::Tensor& targetLengths,
-    const int64_t blank);
diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp
index fc702dc3b7..5adb822f66 100644
--- a/src/libtorchaudio/forced_align/cpu/compute.cpp
+++ b/src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -38,9 +38,9 @@ void forced_align_impl(
   for (int i = 0; i < T * S; i++) {
     backPtr_a[i] = -1;
   }
-  auto logProbs_a = torchaudio::stable::accessor<scalar_t, 3>(logProbs);
-  auto targets_a = torchaudio::stable::accessor<target_t, 2>(targets);
-  auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
+  auto logProbs_a = torchaudio::accessor<scalar_t, 3>(logProbs);
+  auto targets_a = torchaudio::accessor<target_t, 2>(targets);
+  auto paths_a = torchaudio::accessor<target_t, 2>(paths);
   auto R = 0;
   for (auto i = 1; i < L; i++) {
     if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
@@ -147,10 +147,10 @@ template <typename scalar_t>
 const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
 
 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logProbs,
-    const Tensor& targets,
-    const Tensor& inputLengths,
-    const Tensor& targetLengths,
+    Tensor logProbs,
+    Tensor targets,
+    Tensor inputLengths,
+    Tensor targetLengths,
     const int64_t blank) {
   STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
   STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
@@ -224,24 +224,8 @@ std::tuple<Tensor, Tensor> compute(
   return std::make_tuple(paths, logProbs);
 }
 
-void boxed_forced_align_cpu(
-    StableIValue* stack,
-    uint64_t num_args,
-    uint64_t num_outputs) {
-  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
-  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
-  std::tuple<Tensor, Tensor> res = compute(
-      /*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
-      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
-      /*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
-      /*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
-      /*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])));
-  stack[0] = torch::stable::detail::from(std::get<0>(res));
-  stack[1] = torch::stable::detail::from(std::get<1>(res));
-}
-
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("forced_align", &boxed_forced_align_cpu);
+  m.impl("forced_align", TORCH_BOX(&compute));
 }
 
 } // namespace cpu
diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu
index e7430b93df..41f86914a6 100644
--- a/src/libtorchaudio/forced_align/gpu/compute.cu
+++ b/src/libtorchaudio/forced_align/gpu/compute.cu
@@ -1,5 +1,4 @@
 #include <libtorchaudio/utils.h>
-#include <libtorchaudio/stable/TensorAccessor.h>
 #include <libtorchaudio/stable/ops.h>
 #include <c10/cuda/CUDAException.h>
 #include <c10/cuda/CUDAStream.h>
@@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;
 
 template <typename scalar_t, typename target_t>
 __global__ void falign_cuda_step_kernel(
-    const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
+    const torchaudio::PackedTensorAccessor32<scalar_t, 3>
         logProbs_a,
-    const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
+    const torchaudio::PackedTensorAccessor32<target_t, 2>
         targets_a,
     const int T,
     const int L,
@@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
     int start,
     int end,
     int backPtrBufferLen,
-    torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
+    torchaudio::PackedTensorAccessor32<scalar_t, 2>
        alphas_a,
-    torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
+    torchaudio::PackedTensorAccessor32<target_t, 2>
         backPtrBuffer_a) {
   scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   const int batchIndex =
@@ -125,7 +124,7 @@ void forced_align_impl(
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
       conditional<target_scalar_type == ScalarType::Long, int64_t, int32_t>::type;
-  auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
+  auto paths_a = torchaudio::accessor<target_t, 2>(paths);
   const int batchIndex =
       0; // TODO: support batch version and use the real batch index
   const int T = logProbs.size(1); // num frames
@@ -150,8 +149,8 @@ void forced_align_impl(
   torch::stable::fill_(alphas, kNegInfinity);
 
   // CPU accessors
-  auto targetsCpu_a = torchaudio::stable::accessor<target_t, 2>(targetsCpu);
-  auto backPtrCpu_a = torchaudio::stable::accessor<int8_t, 2>(backPtrCpu);
+  auto targetsCpu_a = torchaudio::accessor<target_t, 2>(targetsCpu);
+  auto backPtrCpu_a = torchaudio::accessor<int8_t, 2>(backPtrCpu);
   // count the number of repeats in label
   int R = 0;
   for (int i = 1; i < L; ++i) {
@@ -192,8 +191,8 @@ void forced_align_impl(
     }
     falign_cuda_step_kernel<scalar_t, target_t>
         <<<1, kNumThreads, 0, defaultStream>>>(
-            torchaudio::stable::packed_accessor32<scalar_t, 3>(logProbs),
-            torchaudio::stable::packed_accessor32<target_t, 2>(targets),
+            torchaudio::packed_accessor32<scalar_t, 3>(logProbs),
+            torchaudio::packed_accessor32<target_t, 2>(targets),
             T,
             L,
             N,
@@ -203,8 +202,8 @@ void forced_align_impl(
             start,
             end,
             backPtrBufferLen,
-            torchaudio::stable::packed_accessor32<scalar_t, 2>(alphas),
-            torchaudio::stable::packed_accessor32<target_t, 2>(backPtrBuffer));
+            torchaudio::packed_accessor32<scalar_t, 2>(alphas),
+            torchaudio::packed_accessor32<target_t, 2>(backPtrBuffer));
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     ++backPtrBufferLen;
     if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
@@ -228,9 +227,8 @@ void forced_align_impl(
     }
   }
   cpuDataTranferStream.synchronize();
-  auto alphasCpu = torchaudio::stable::cpu(alphas);
-  auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
+  auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
   int curIdxOffset = ((T - 1) % 2);
   int ltrIdx =
       alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
@@ -244,18 +242,11 @@ void forced_align_impl(
   }
 }
 
-template <typename scalar_t>
-const auto forced_align_long_impl =
-    forced_align_impl<scalar_t, ScalarType::Long>;
-
-template <typename scalar_t>
-const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
-
 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logProbs,
-    const Tensor& targets,
-    const Tensor& inputLengths,
-    const Tensor& targetLengths,
+    Tensor logProbs,
+    Tensor targets,
+    Tensor inputLengths,
+    Tensor targetLengths,
     const int64_t blank) {
   STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
@@ -307,31 +298,17 @@ std::tuple<Tensor, Tensor> compute(
   THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
       if (targets.scalar_type() == ScalarType::Long) {
-        forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
+        (forced_align_impl<scalar_t, ScalarType::Long>(logProbs, targets, blank, paths));
       } else {
-        forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
-      }
+        (forced_align_impl<scalar_t, ScalarType::Int>(logProbs, targets, blank, paths));
+      }
   }), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);
-  Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
   return std::make_tuple(pathsCuda, logProbs);
 }
 
-void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
-  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
-  std::tuple<Tensor, Tensor> res = compute(
-      /*logProbs*/torch::stable::detail::to<Tensor>(stack[0]),
-      /*targets*/torch::stable::detail::to<Tensor>(stack[1]),
-      /*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
-      /*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
-      /*blank*/float(torch::stable::detail::to<int64_t>(stack[4])));
-  stack[0] = torch::stable::detail::from(std::get<0>(res));
-  stack[1] = torch::stable::detail::from(std::get<1>(res));
-}
-
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  m.impl("forced_align", &boxed_forced_align_gpu);
+  m.impl("forced_align", TORCH_BOX(&compute));
 }
 
 } // namespace gpu
diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp
index 4e3c8bf8c9..ab527ace4e 100644
--- a/src/libtorchaudio/overdrive.cpp
+++ b/src/libtorchaudio/overdrive.cpp
@@ -1,3 +1,4 @@
+#include <libtorchaudio/utils.h>
 #include <torch/csrc/stable/library.h>
 #include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
@@ -5,27 +6,15 @@
 #include <torch/headeronly/core/Dispatch_v2.h>
 
 namespace {
-
 using torch::stable::Tensor;
 
-template <typename T, size_t N>
-using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
-
-// TODO: eliminate accessor(t) in favor of t.accessor<T, N>()
-// after Tensor::accessor is supported in stable ABI
-template <typename T, size_t N>
-inline TensorAccessor<T, N> accessor(Tensor t) {
-  return TensorAccessor<T, N>(
-      reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
-}
-
 template <typename scalar_t>
 void overdrive_cpu_kernel(
-    TensorAccessor<scalar_t, 2> waveform_accessor,
-    TensorAccessor<scalar_t, 2> temp_accessor,
-    TensorAccessor<scalar_t, 1> last_in_accessor,
-    TensorAccessor<scalar_t, 1> last_out_accessor,
-    TensorAccessor<scalar_t, 2> output_waveform_accessor) {
+    torchaudio::TensorAccessor<scalar_t, 2> waveform_accessor,
+    torchaudio::TensorAccessor<scalar_t, 2> temp_accessor,
+    torchaudio::TensorAccessor<scalar_t, 1> last_in_accessor,
+    torchaudio::TensorAccessor<scalar_t, 1> last_out_accessor,
+    torchaudio::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
   int64_t n_frames = waveform_accessor.size(1);
   int64_t n_channels = waveform_accessor.size(0);
 
@@ -56,11 +45,11 @@ std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
         "overdrive_cpu",
         AT_WRAP([&] {
           overdrive_cpu_kernel<scalar_t>(
-              accessor<scalar_t, 2>(waveform),
-              accessor<scalar_t, 2>(temp),
-              accessor<scalar_t, 1>(last_in),
-              accessor<scalar_t, 1>(last_out),
-              accessor<scalar_t, 2>(output_waveform));
+              torchaudio::accessor<scalar_t, 2>(waveform),
+              torchaudio::accessor<scalar_t, 2>(temp),
+              torchaudio::accessor<scalar_t, 1>(last_in),
+              torchaudio::accessor<scalar_t, 1>(last_out),
+              torchaudio::accessor<scalar_t, 2>(output_waveform));
         }),
         AT_FLOATING_TYPES);
   return std::make_tuple(last_in, last_out, output_waveform);
@@ -70,7 +59,11 @@ std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
 
 STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
-      "_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
+      "_overdrive_core_loop(Tensor waveform,"
+      "Tensor temp,"
+      "Tensor(a!) last_in,"
+      "Tensor(b!) last_out,"
+      "Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp
index 811833c96f..0449b8afcc 100644
--- a/src/libtorchaudio/rnnt/cpu/compute.cpp
+++ b/src/libtorchaudio/rnnt/cpu/compute.cpp
@@ -2,6 +2,7 @@
 #include <libtorchaudio/stable/ops.h>
 #include <torch/csrc/stable/library.h>
 #include <torch/csrc/stable/tensor.h>
+#include <torch/headeronly/core/Dispatch_v2.h>
 
 namespace torchaudio {
 namespace rnnt {
@@ -12,10 +13,10 @@ using torch::stable::Tensor;
 
 // Entry point into RNNT Loss
 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logits,
-    const Tensor& targets,
-    const Tensor& logit_lengths,
-    const Tensor& target_lengths,
+    Tensor logits,
+    Tensor targets,
+    Tensor logit_lengths,
+    Tensor target_lengths,
     int64_t blank,
     double clamp,
     bool fused_log_softmax = true) {
@@ -109,6 +110,8 @@ std::tuple<Tensor, Tensor> compute(
       {DtypeWorkspace<float>::ComputeSizeFromOptions(options)},
       ScalarType::Float);
 
+  // TODO: use t.mutable_data_ptr<..>() instead of reinterpret_cast
+  // when stable ABI Tensor supports mutable_data_ptr templates.
   Workspace<float> workspace(
       /*options=*/options,
       /*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
@@ -116,57 +119,27 @@ std::tuple<Tensor, Tensor> compute(
       /*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
       /*int_size=*/int_workspace.numel());
 
-  switch (logits.scalar_type()) {
-    case ScalarType::Float: {
-      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
-          /*workspace=*/workspace,
-          /*logits=*/reinterpret_cast<const float*>(logits.data_ptr()),
-          /*targets=*/reinterpret_cast<const int*>(targets.data_ptr()),
-          /*srcLengths=*/reinterpret_cast<const int*>(logit_lengths.data_ptr()),
-          /*tgtLengths=*/reinterpret_cast<const int*>(target_lengths.data_ptr()),
-          /*costs=*/reinterpret_cast<float*>(costs.data_ptr()),
-          /*gradients=*/reinterpret_cast<float*>(gradients.data_ptr()));
-      break;
-    }
-    case ScalarType::Half: {
-      Compute</*DTYPE=*/torch::headeronly::Half, /*CAST_DTYPE=*/float>(
-          /*workspace=*/workspace,
-          /*logits=*/reinterpret_cast<const torch::headeronly::Half*>(logits.data_ptr()),
-          /*targets=*/reinterpret_cast<const int*>(targets.data_ptr()),
-          /*srcLengths=*/reinterpret_cast<const int*>(logit_lengths.data_ptr()),
-          /*tgtLengths=*/reinterpret_cast<const int*>(target_lengths.data_ptr()),
-          /*costs=*/reinterpret_cast<torch::headeronly::Half*>(costs.data_ptr()),
-          /*gradients=*/reinterpret_cast<torch::headeronly::Half*>(gradients.data_ptr()));
-      break;
-    }
-    default: {
-      STD_TORCH_CHECK(false, "unreachable");
-    }
-  };
+  THO_DISPATCH_V2(
+      logits.scalar_type(),
+      "rnnt:compute",
+      AT_WRAP([&] {
+        (Compute</*DTYPE=*/scalar_t, /*CAST_DTYPE=*/float>(
+            /*workspace=*/workspace,
+            /*logits=*/reinterpret_cast<const scalar_t*>(logits.data_ptr()),
+            /*targets=*/reinterpret_cast<const int*>(targets.data_ptr()),
+            /*srcLengths=*/reinterpret_cast<const int*>(logit_lengths.data_ptr()),
+            /*tgtLengths=*/reinterpret_cast<const int*>(target_lengths.data_ptr()),
+            /*costs=*/reinterpret_cast<scalar_t*>(costs.data_ptr()),
+            /*gradients=*/reinterpret_cast<scalar_t*>(gradients.data_ptr())));
+      }),
+      ScalarType::Float,
+      ScalarType::Half);
 
   return std::make_tuple(costs, gradients);
 }
 
-void boxed_rnnt_loss(
-    StableIValue* stack,
-    uint64_t num_args,
-    uint64_t num_outputs) {
-  STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
-  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
-  std::tuple<Tensor, Tensor> res = compute(
-      /*logits*/ torch::stable::detail::to<Tensor>(stack[0]),
-      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
-      /*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
-      /*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
-      /*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])),
-      /*clamp*/ torch::stable::detail::to<double>(stack[5]),
-      /*fused_log_softmax*/ torch::stable::detail::to<bool>(stack[6]));
-  stack[0] = torch::stable::detail::from(std::get<0>(res));
-  stack[1] = torch::stable::detail::from(std::get<1>(res));
-}
-
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
+  m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
 }
 
 } // namespace cpu
diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu
index 77c6a0e268..800f029121 100644
--- a/src/libtorchaudio/rnnt/gpu/compute.cu
+++ b/src/libtorchaudio/rnnt/gpu/compute.cu
@@ -148,11 +148,6 @@ std::tuple<Tensor, Tensor> compute(
   return std::make_tuple(costs, gradients);
 }
 
-STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def(
-      "rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, float clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
-}
-
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
   m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
 }
diff --git a/src/libtorchaudio/stable/TensorAccessor.h b/src/libtorchaudio/stable/TensorAccessor.h
deleted file mode 100644
index db786dafa6..0000000000
--- a/src/libtorchaudio/stable/TensorAccessor.h
+++ /dev/null
@@ -1,275 +0,0 @@
-#pragma once
-/*
-  This header file provides torchaudio::stable::TensorAccessor<T, N>
-  templates that are torch::stable::Tensor-compatible analogues of
-  at::TensorAccessor<T, N> defined in ATen/core/TensorAccessor.h.
-
-  TODO: remove this header file when torch::stable provides all
-  features implemented here.
-*/
-
-// #include <ATen/core/TensorAccessor.h>
-
-#include <torch/headeronly/macros/Macros.h>
-#include <torch/headeronly/util/Exception.h>
-
-namespace torchaudio::stable {
-
-template <typename T>
-struct DefaultPtrTraits {
-  typedef T* PtrType;
-};
-
-#if defined(__CUDACC__) || defined(__HIPCC__)
-template <typename T>
-struct RestrictPtrTraits {
-  typedef T* __restrict__ PtrType;
-};
-#endif
-
-template <
-    typename T,
-    size_t N,
-    template <typename U> class PtrTraits = DefaultPtrTraits,
-    typename index_t = int64_t>
-class TensorAccessorBase {
- public:
-  typedef typename PtrTraits<T>::PtrType PtrType;
-
-  C10_HOST_DEVICE TensorAccessorBase(
-      PtrType data_,
-      const index_t* sizes_,
-      const index_t* strides_)
-      : data_(data_) /*, sizes_(sizes_), strides_(strides_)*/ {
-    // Originally, TensorAccessor is a view of sizes and strides as
-    // these are ArrayRef instances. Until torch::stable supports
-    // ArrayRef-like features, we store copies of sizes and strides:
-    for (size_t i = 0; i < N; ++i) {
-      this->sizes_[i] = sizes_[i];
-      this->strides_[i] = strides_[i];
-    }
-  }
-
-  C10_HOST_DEVICE PtrType data() {
-    return data_;
-  }
-  C10_HOST_DEVICE const PtrType data() const {
-    return data_;
-  }
-
- protected:
-  PtrType data_;
-  /*
-  const index_t* sizes_;
-  const index_t* strides_;
-  */
-  // NOLINTNEXTLINE(*c-arrays*)
-  index_t sizes_[N];
-  // NOLINTNEXTLINE(*c-arrays*)
-  index_t strides_[N];
-};
-
-template <
-    typename T,
-    size_t N,
-    template <typename U> class PtrTraits = DefaultPtrTraits,
-    typename index_t = int64_t>
-class TensorAccessor : public TensorAccessorBase<T, N, PtrTraits, index_t> {
- public:
-  typedef typename PtrTraits<T>::PtrType PtrType;
-
-  C10_HOST_DEVICE TensorAccessor(
-      PtrType data_,
-      const index_t* sizes_,
-      const index_t* strides_)
-      : TensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
-
-  C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
-      index_t i) {
-    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
-        this->data_ + this->strides_[0] * i,
-        this->sizes_ + 1,
-        this->strides_ + 1);
-  }
-
-  C10_HOST_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
-      index_t i) const {
-    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
-        this->data_ + this->strides_[0] * i,
-        this->sizes_ + 1,
-        this->strides_ + 1);
-  }
-};
-
-template <typename T, template <typename U> class PtrTraits, typename index_t>
-class TensorAccessor<T, 1, PtrTraits, index_t>
-    : public TensorAccessorBase<T, 1, PtrTraits, index_t> {
- public:
-  typedef typename PtrTraits<T>::PtrType PtrType;
-
-  C10_HOST_DEVICE TensorAccessor(
-      PtrType data_,
-      const index_t* sizes_,
-      const index_t* strides_)
-      : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
-  C10_HOST_DEVICE T& operator[](index_t i) {
-    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
-    return this->data_[this->strides_[0] * i];
-  }
-  C10_HOST_DEVICE const T& operator[](index_t i) const {
-    return this->data_[this->strides_[0] * i];
-  }
-};
-
-template <
-    typename T,
-    size_t N,
-    template <typename U> class PtrTraits = DefaultPtrTraits,
-    typename index_t = int64_t>
-class GenericPackedTensorAccessorBase {
- public:
-  typedef typename PtrTraits<T>::PtrType PtrType;
-  C10_HOST GenericPackedTensorAccessorBase(
-      PtrType data_,
-      const index_t* sizes_,
-      const index_t* strides_)
-      : data_(data_) {
-    std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
-    std::copy(strides_, strides_ + N, std::begin(this->strides_));
-  }
-
-  template <
-      typename source_index_t,
-      class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
-  C10_HOST GenericPackedTensorAccessorBase(
-      PtrType data_,
-      const source_index_t* sizes_,
-      const source_index_t* strides_)
-      : data_(data_) {
-    for (size_t i = 0; i < N; ++i) {
-      this->sizes_[i] = sizes_[i];
-      this->strides_[i] = strides_[i];
-    }
-  }
-
-  C10_HOST_DEVICE PtrType data() {
-    return data_;
-  }
-  C10_HOST_DEVICE const PtrType data() const {
-    return data_;
-  }
-
- protected:
-  PtrType data_;
-  // NOLINTNEXTLINE(*c-arrays*)
-  index_t sizes_[N];
-  // NOLINTNEXTLINE(*c-arrays*)
-  index_t strides_[N];
-  C10_HOST void bounds_check_(index_t i) const {
-    STD_TORCH_CHECK(
-        0 <= i && i < index_t{N},
-        "Index ",
-        i,
-        " is not within bounds of a tensor of dimension ",
-        N);
-  }
-};
-
-template <
-    typename T,
-    size_t N,
-    template <typename U> class PtrTraits = DefaultPtrTraits,
-    typename index_t = int64_t>
-class GenericPackedTensorAccessor
-    : public GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t> {
- public:
-  typedef typename PtrTraits<T>::PtrType PtrType;
-
-  C10_HOST GenericPackedTensorAccessor(
-      PtrType data_,
-      const index_t* sizes_,
-      const index_t* strides_)
-      : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
-            data_,
-            sizes_,
-            strides_) {}
-
-  // if index_t is not int64_t, we want to have an int64_t constructor
-  template <
-      typename source_index_t,
-      class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
-  C10_HOST GenericPackedTensorAccessor(
-      PtrType data_,
-      const source_index_t* sizes_,
-      const source_index_t* strides_)
-      : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
-            data_,
-            sizes_,
-            strides_) {}
-
-  C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
-      index_t i) {
-    index_t* new_sizes = this->sizes_ + 1;
-    index_t* new_strides = this->strides_ + 1;
-    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
-        this->data_ + this->strides_[0] * i, new_sizes, new_strides);
-  }
-
-  C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
-      index_t i) const {
-    const index_t* new_sizes = this->sizes_ + 1;
-    const index_t* new_strides = this->strides_ + 1;
-    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
-        this->data_ + this->strides_[0] * i, new_sizes, new_strides);
-  }
-};
-
-template <typename T, template <typename U> class PtrTraits, typename index_t>
-class GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>
-    : public GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t> {
- public:
-  typedef typename PtrTraits<T>::PtrType PtrType;
-  C10_HOST GenericPackedTensorAccessor(
-      PtrType data_,
-      const index_t* sizes_,
-      const index_t* strides_)
-      : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
-            data_,
-            sizes_,
-            strides_) {}
-
-  template <
-      typename source_index_t,
-      class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
-  C10_HOST GenericPackedTensorAccessor(
-      PtrType data_,
-      const source_index_t* sizes_,
-      const source_index_t* strides_)
-      : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
-            data_,
-            sizes_,
-            strides_) {}
-
-  C10_DEVICE T& operator[](index_t i) {
-    return this->data_[this->strides_[0] * i];
-  }
-  C10_DEVICE const T& operator[](index_t i) const {
-    return this->data_[this->strides_[0] * i];
-  }
-};
-
-template <
-    typename T,
-    size_t N,
-    template <typename U> class PtrTraits = DefaultPtrTraits>
-using PackedTensorAccessor32 =
-    GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;
-
-template <
-    typename T,
-    size_t N,
-    template <typename U> class PtrTraits = DefaultPtrTraits>
-using PackedTensorAccessor64 =
-    GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;
-
-} // namespace torchaudio::stable
diff --git a/src/libtorchaudio/stable/ops.h b/src/libtorchaudio/stable/ops.h
index 529dbccd3b..fa5a89345d 100644
--- a/src/libtorchaudio/stable/ops.h
+++ b/src/libtorchaudio/stable/ops.h
@@ -10,7 +10,6 @@
 */
 
 #include <torch/csrc/stable/tensor.h>
-#include <libtorchaudio/stable/TensorAccessor.h>
 #include <torch/csrc/stable/ops.h>
 
 #ifdef USE_CUDA
@@ -59,87 +58,6 @@ const T* const_data_ptr(const Tensor& t) {
   return reinterpret_cast<const T*>(data_ptr);
 }
 
-// TODO: When accessor is implemented in torch::stable, eliminate
-// accessor template below.
-template <typename T, size_t N>
-torchaudio::stable::TensorAccessor<T, N> accessor(const Tensor& t) {
-  static_assert(
-      N > 0,
-      "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
-  STD_TORCH_CHECK(
-      t.dim() == N,
-      "TensorAccessor expected ",
-      N,
-      " dims but tensor has ",
-      t.dim());
-  T* ptr = nullptr;
-  if constexpr (std::is_const_v<T>) {
-    ptr = const_data_ptr<T>(t);
-  } else {
-    ptr = mutable_data_ptr<T>(t);
-  }
-  auto sizes_ = sizes(t);
-  auto strides_ = strides(t);
-  return torchaudio::stable::TensorAccessor<T, N>(
-      ptr, sizes_.data(), strides_.data());
-}
-
-// TODO: move to TensorAccessor.h?
-template <
-    typename T,
-    size_t N,
-    template <typename U>
-    class PtrTraits = torchaudio::stable::DefaultPtrTraits,
-    typename index_t = int64_t>
-torchaudio::stable::GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
-generic_packed_accessor(const Tensor& t) {
-  static_assert(
-      N > 0,
-      "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
-  STD_TORCH_CHECK(
-      t.dim() == N,
-      "TensorAccessor expected ",
-      N,
-      " dims but tensor has ",
-      t.dim());
-  T* ptr = nullptr;
-  if constexpr (std::is_const_v<T>) {
-    ptr = const_data_ptr<T>(t);
-  } else {
-    ptr = mutable_data_ptr<T>(t);
-  }
-  auto sizes_ = sizes(t);
-  auto strides_ = strides(t);
-  return torchaudio::stable::
-      GenericPackedTensorAccessor<T, N, PtrTraits, index_t>(
-          static_cast<typename PtrTraits<T>::PtrType>(ptr),
-          sizes_.data(),
-          strides_.data());
-}
-
-template <
-    typename T,
-    size_t N,
-    template <typename U>
-    class PtrTraits = torchaudio::stable::DefaultPtrTraits>
-torchaudio::stable::PackedTensorAccessor32<T, N, PtrTraits> packed_accessor32(
-    const Tensor& t) {
-  STD_TORCH_CHECK(
-      t.numel() <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
-      "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
-  return generic_packed_accessor<T, N, PtrTraits, int32_t>(t);
-}
-
-template <
-    typename T,
-    size_t N,
-    template <typename U>
-    class PtrTraits = torchaudio::stable::DefaultPtrTraits>
-torchaudio::stable::PackedTensorAccessor64<T, N, PtrTraits> packed_accessor64(
-    const Tensor& t) {
-  return generic_packed_accessor<T, N, PtrTraits, int64_t>(t);
-}
-
 // TODO: When cpu is implemented in torch::stable, eliminate
 // cpu function below.
 inline Tensor cpu(const Tensor& self) {
diff --git a/src/libtorchaudio/utils.h b/src/libtorchaudio/utils.h
index 4693f7e91c..a700e30efa 100644
--- a/src/libtorchaudio/utils.h
+++ b/src/libtorchaudio/utils.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <torch/headeronly/core/TensorAccessor.h>
+
 // TODO: replace the include libtorchaudio/stable/ops.h with
 // torch/stable/ops.h when torch::stable provides all required
 // features (torch::stable::item or similar):
@@ -17,4 +19,35 @@ T max(const torch::stable::Tensor& t) {
 
 bool is_align_available();
 std::optional<int64_t> cuda_version();
 
+template <typename T, size_t N>
+using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
+
+// TODO: eliminate accessor(t) in favor of t.accessor<T, N>()
+// after Tensor::accessor is supported in stable ABI
+template <typename T, size_t N>
+inline TensorAccessor<T, N> accessor(Tensor t) {
+  return TensorAccessor<T, N>(
+      reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
+}
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <typename T, size_t N>
+using PackedTensorAccessor32 =
+    torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<
+        T,
+        N,
+        torch::headeronly::RestrictPtrTraits,
+        int32_t>;
+
+// TODO: eliminate packed_accessor32(t) in favor of t.packed_accessor32<T, N>()
+// after Tensor::packed_accessor32 is supported in stable ABI
+template <typename T, size_t N>
+inline PackedTensorAccessor32<T, N> packed_accessor32(Tensor t) {
+  return PackedTensorAccessor32<T, N>(
+      static_cast<typename torch::headeronly::RestrictPtrTraits<T>::PtrType>(
+          t.data_ptr()),
+      t.sizes().data(),
+      t.strides().data());
+}
+#endif
+
 } // namespace torchaudio
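---

Reviewer notes (editorial), not part of the patch:

Every deleted `boxed_*` function above follows the same stack-marshalling pattern, which `TORCH_BOX` now derives from the typed signature instead. A minimal sketch of the before/after, assuming the hypothetical names `my_op` and `myns` (the `torch::stable::detail::to`/`from` helpers are the ones visible in the deleted code):

```cpp
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// Typed entry point with unboxed arguments (hypothetical).
Tensor my_op(Tensor x, int64_t k) {
  (void)k;
  return x;
}

// Before: a manual shim converting between the stable ABI's boxed calling
// convention (a StableIValue stack) and the typed signature.
void boxed_my_op(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 2, "num_args must be 2");
  STD_TORCH_CHECK(num_outputs == 1, "num_outputs must be 1");
  Tensor res = my_op(
      torch::stable::detail::to<Tensor>(stack[0]),
      torch::stable::detail::to<int64_t>(stack[1]));
  stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_IMPL(myns, CPU, m) {
  // After: TORCH_BOX generates the equivalent shim from my_op's signature.
  m.impl("my_op", TORCH_BOX(&my_op));
}
```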
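Likewise, the hand-written `switch (logits.scalar_type())` blocks are replaced by `THO_DISPATCH_V2`, which instantiates the `AT_WRAP`-wrapped lambda once per listed dtype and binds `scalar_t` inside it. A sketch under the same headers the patch uses (`fill_one` is hypothetical):

```cpp
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/Dispatch_v2.h>

using torch::stable::Tensor;

// Hypothetical kernel: fill a floating-point tensor with ones, dispatching
// on dtype the same way the rnnt and forced_align changes do.
void fill_one(Tensor t) {
  THO_DISPATCH_V2(
      t.scalar_type(),
      "fill_one",
      AT_WRAP([&] {
        auto* p = reinterpret_cast<scalar_t*>(t.data_ptr());
        for (int64_t i = 0; i < t.numel(); ++i) {
          p[i] = scalar_t(1); // scalar_t is bound per dispatched dtype
        }
      }),
      AT_EXPAND(AT_FLOATING_TYPES),
      torch::headeronly::ScalarType::Half);
}
```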
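Finally, the header-only accessors that replace `torchaudio::stable::TensorAccessor` wrap a raw data pointer plus the tensor's sizes and strides arrays; `operator[]` applies strides with no bounds checking, exactly as the ATen accessors do. A usage sketch (`channel_sum` is hypothetical, and the Float dtype is an assumption):

```cpp
#include <libtorchaudio/utils.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// Hypothetical helper: sum one channel of a 2-D [channel][frame] tensor
// through torchaudio::accessor, mirroring overdrive_cpu_kernel's access
// pattern.
float channel_sum(Tensor waveform, int64_t channel) {
  auto a = torchaudio::accessor<float, 2>(waveform); // assumes Float dtype
  float total = 0.f;
  for (int64_t frame = 0; frame < waveform.size(1); ++frame) {
    total += a[channel][frame]; // strided access, no bounds check
  }
  return total;
}
```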