
Commit ee07051

Collect a torch::stable wishlist in src/libtorchaudio/stable

1 parent 48081cf commit ee07051

File tree

4 files changed: +608 -0 lines changed

src/libtorchaudio/stable/Device.h

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
#pragma once

/*
  This header file provides the torchaudio::stable::Device struct, a
  torch::stable::Tensor-compatible analog of c10::Device defined in
  c10/core/Device.h.

  TODO: remove this header file when torch::stable provides all
  features implemented here.
*/

#include <torch/csrc/stable/accelerator.h>

namespace torchaudio::stable {

using DeviceType = int32_t;
using torch::stable::accelerator::DeviceIndex;

struct Device {

  Device(DeviceType type, DeviceIndex index = -1)
      : type_(type), index_(index) {
    // TODO: validate();
  }

  /// Returns the type of device this is.
  DeviceType type() const noexcept {
    return type_;
  }

  /// Returns the optional index.
  DeviceIndex index() const noexcept {
    return index_;
  }

 private:
  DeviceType type_;
  DeviceIndex index_ = -1;
};

// A convenience function, not a part of torch::stable
inline Device cpu_device() {
  Device d(aoti_torch_device_type_cpu(), 0);
  return d;
}

} // namespace torchaudio::stable
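
For orientation, a minimal usage sketch (not part of the commit). It assumes aoti_torch_device_type_cuda() is exported by the same AOTI shim that provides aoti_torch_device_type_cpu(); adjust if that symbol differs.

#include <libtorchaudio/stable/Device.h>

// Sketch only: construct devices and compare their types.
inline bool example_runs_on_cpu() {
  torchaudio::stable::Device cpu = torchaudio::stable::cpu_device();
  // Assumed shim call; index 0 selects the first CUDA device.
  torchaudio::stable::Device cuda0(aoti_torch_device_type_cuda(), /*index=*/0);
  return cpu.type() != cuda0.type() &&
         cpu.type() == aoti_torch_device_type_cpu();
}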
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
#pragma once
/*
  This header file provides torchaudio::stable::TensorAccessor
  templates that are torch::stable::Tensor-compatible analogs of
  at::TensorAccessor defined in ATen/core/TensorAccessor.h.

  TODO: remove this header file when torch::stable provides all
  features implemented here.
*/

// #include <libtorchaudio/stable/Device.h>

#include <torch/headeronly/macros/Macros.h>
#include <algorithm> // std::copy
#include <type_traits>

namespace torchaudio::stable {

template <typename T>
struct DefaultPtrTraits {
  typedef T* PtrType;
};

#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename T>
struct RestrictPtrTraits {
  typedef T* __restrict__ PtrType;
};
#endif

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessorBase {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;

  C10_HOST_DEVICE TensorAccessorBase(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : data_(data_) /*, sizes_(sizes_), strides_(strides_)*/ {
    // Originally, TensorAccessor is a view of sizes and strides as
    // these are ArrayRef instances. Until torch::stable supports
    // ArrayRef-like features, we store copies of sizes and strides:
    for (auto i = 0; i < N; ++i) {
      this->sizes_[i] = sizes_[i];
      this->strides_[i] = strides_[i];
    }
  }

  C10_HOST_DEVICE PtrType data() {
    return data_;
  }
  C10_HOST_DEVICE const PtrType data() const {
    return data_;
  }

 protected:
  PtrType data_;
  /*
  const index_t* sizes_;
  const index_t* strides_;
  */
  // NOLINTNEXTLINE(*c-arrays*)
  index_t sizes_[N];
  // NOLINTNEXTLINE(*c-arrays*)
  index_t strides_[N];
};

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessor : public TensorAccessorBase<T, N, PtrTraits, index_t> {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;

  C10_HOST_DEVICE TensorAccessor(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : TensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

  C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(this->data_ + this->strides_[0] * i, this->sizes_ + 1, this->strides_ + 1);
  }

  C10_HOST_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(this->data_ + this->strides_[0] * i, this->sizes_ + 1, this->strides_ + 1);
  }
};

template <typename T, template <typename U> class PtrTraits, typename index_t>
class TensorAccessor<T, 1, PtrTraits, index_t> : public TensorAccessorBase<T, 1, PtrTraits, index_t> {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;

  C10_HOST_DEVICE TensorAccessor(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
  C10_HOST_DEVICE T& operator[](index_t i) {
    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    return this->data_[this->strides_[0] * i];
  }
  C10_HOST_DEVICE const T& operator[](index_t i) const {
    return this->data_[this->strides_[0] * i];
  }
};

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessorBase {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;
  C10_HOST GenericPackedTensorAccessorBase(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : data_(data_) {
    std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
    std::copy(strides_, strides_ + N, std::begin(this->strides_));
  }

  template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
  C10_HOST GenericPackedTensorAccessorBase(
      PtrType data_,
      const source_index_t* sizes_,
      const source_index_t* strides_)
      : data_(data_) {
    for (auto i = 0; i < N; ++i) {
      this->sizes_[i] = sizes_[i];
      this->strides_[i] = strides_[i];
    }
  }

  C10_HOST_DEVICE PtrType data() {
    return data_;
  }
  C10_HOST_DEVICE const PtrType data() const {
    return data_;
  }

 protected:
  PtrType data_;
  // NOLINTNEXTLINE(*c-arrays*)
  index_t sizes_[N];
  // NOLINTNEXTLINE(*c-arrays*)
  index_t strides_[N];
  C10_HOST void bounds_check_(index_t i) const {
    STD_TORCH_CHECK(
        0 <= i && i < index_t{N},
        "Index ",
        i,
        " is not within bounds of a tensor of dimension ",
        N);
  }
};

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t> {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;

  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

  // if index_t is not int64_t, we want to have an int64_t constructor
  template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const source_index_t* sizes_,
      const source_index_t* strides_)
      : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

  C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
    index_t* new_sizes = this->sizes_ + 1;
    index_t* new_strides = this->strides_ + 1;
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(this->data_ + this->strides_[0] * i, new_sizes, new_strides);
  }

  C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
    const index_t* new_sizes = this->sizes_ + 1;
    const index_t* new_strides = this->strides_ + 1;
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(this->data_ + this->strides_[0] * i, new_sizes, new_strides);
  }
};

template <typename T, template <typename U> class PtrTraits, typename index_t>
class GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> : public GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t> {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;
  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

  template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const source_index_t* sizes_,
      const source_index_t* strides_)
      : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

  C10_DEVICE T& operator[](index_t i) {
    return this->data_[this->strides_[0] * i];
  }
  C10_DEVICE const T& operator[](index_t i) const {
    return this->data_[this->strides_[0] * i];
  }
};

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;

} // namespace torchaudio::stable
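
A brief host-side usage sketch (not part of the commit). It assumes torch::stable::Tensor exposes data_ptr(), size(i), and stride(i); the packed variants are intended to be passed by value into CUDA/HIP kernels, in the same spirit as at::PackedTensorAccessor32/64.

#include <libtorchaudio/stable/TensorAccessor.h>
#include <torch/csrc/stable/tensor.h>

// Sketch only: sum a (possibly strided) 2-D float tensor on the host.
inline float sum2d(const torch::stable::Tensor& t) {
  const int64_t sizes[2] = {t.size(0), t.size(1)};
  const int64_t strides[2] = {t.stride(0), t.stride(1)};
  torchaudio::stable::TensorAccessor<float, 2> acc(
      static_cast<float*>(t.data_ptr()), sizes, strides);
  float total = 0.0f;
  for (int64_t i = 0; i < sizes[0]; ++i) {
    for (int64_t j = 0; j < sizes[1]; ++j) {
      total += acc[i][j];
    }
  }
  return total;
}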
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
#pragma once
/*
  This header file provides the CPP macro

    STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...)

  which is a torch::stable::Tensor-compatible analog of

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...)

  TODO: remove this header file when torch::stable provides all
  features implemented here.
*/

#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/core/ScalarType.h>

namespace torchaudio::stable {

using torch::headeronly::ScalarType;

namespace impl {

template <ScalarType N>
struct ScalarTypeToCPPType;

#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
  template <>                                                 \
  struct ScalarTypeToCPPType<ScalarType::scalar_type> {       \
    using type = cpp_type;                                    \
  };

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)

#undef SPECIALIZE_ScalarTypeToCPPType

template <ScalarType N>
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;

} // namespace impl

} // namespace torchaudio::stable

#define STABLE_DISPATCH_CASE(enum_type, ...)                                                      \
  case enum_type: {                                                                               \
    using scalar_t [[maybe_unused]] = torchaudio::stable::impl::ScalarTypeToCPPTypeT<enum_type>;  \
    return __VA_ARGS__();                                                                         \
  }

#define STABLE_DISPATCH_SWITCH(TYPE, NAME, ...)    \
  [&] {                                            \
    const auto& the_type = TYPE;                   \
    constexpr const char* at_dispatch_name = NAME; \
    switch (the_type) {                            \
      __VA_ARGS__                                  \
      default:                                     \
        STD_TORCH_CHECK(                           \
            false,                                 \
            '"',                                   \
            at_dispatch_name,                      \
            "\" not implemented for '",            \
            toString(the_type),                    \
            "'");                                  \
    }                                              \
  }()

#define STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
  STABLE_DISPATCH_CASE(ScalarType::Double, __VA_ARGS__)   \
  STABLE_DISPATCH_CASE(ScalarType::Float, __VA_ARGS__)    \
  STABLE_DISPATCH_CASE(ScalarType::Half, __VA_ARGS__)

#define STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  STABLE_DISPATCH_SWITCH(                                        \
      TYPE, NAME, STABLE_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
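
A usage sketch (not part of the commit), mirroring how AT_DISPATCH_FLOATING_TYPES_AND_HALF is typically used. It assumes the dispatch header above is included and that the tensor's dtype is available as a torch::headeronly::ScalarType (written as t.scalar_type() here purely for illustration).

#include <torch/csrc/stable/tensor.h>

using torchaudio::stable::ScalarType; // needed at the macro expansion site

// Sketch only: scale every element of a float/double/half tensor in place.
inline void scale_(torch::stable::Tensor& t, double factor) {
  STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "scale_", [&] {
    auto* data = static_cast<scalar_t*>(t.data_ptr());
    for (int64_t i = 0; i < t.numel(); ++i) {
      data[i] = static_cast<scalar_t>(static_cast<double>(data[i]) * factor);
    }
  });
}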
