@@ -39,100 +39,121 @@ struct ET_EXPERIMENTAL RawAudio {
 * Mel spectrograms are typically represented as floating point values. For raw
 * or quantized audio, uint8_t may be used instead.
 */
42- class ET_EXPERIMENTAL Audio {
42+ class ET_EXPERIMENTAL Audio final {
4343 public:
4444 // Default constructor
45- Audio () : batch_size (0 ), n_bins (0 ), n_frames (0 ) {}
45+ Audio () : batch_size_ (0 ), n_bins_ (0 ), n_frames_ (0 ) {}
4646
4747 // Constructor for uint8_t data
4848 Audio (
49- std::vector<uint8_t >&& data_,
50- int32_t batch_size_,
51- int32_t n_bins_,
52- int32_t n_frames_)
53- : data(std::move(data_)),
54- batch_size (batch_size_),
55- n_bins(n_bins_),
56- n_frames(n_frames_) {}
49+ std::vector<uint8_t >&& data,
50+ int32_t batch_size,
51+ int32_t n_bins,
52+ int32_t n_frames)
53+ : data_(std::move(data)),
54+ batch_size_ (batch_size),
55+ n_bins_(n_bins),
56+ n_frames_(n_frames) {
57+ ET_CHECK_MSG (
58+ data_.index () == 0 &&
59+ std::get<std::vector<uint8_t >>(data_).size () ==
60+ static_cast <size_t >(batch_size * n_bins * n_frames),
61+ " data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)" ,
62+ std::get<std::vector<uint8_t >>(data_).size (),
63+ batch_size * n_bins * n_frames);
64+ }
5765
5866 // Constructor for float data
5967 Audio (
60- std::vector<float >&& data_,
61- int32_t batch_size_,
62- int32_t n_bins_,
63- int32_t n_frames_)
64- : data(std::move(data_)),
65- batch_size(batch_size_),
66- n_bins(n_bins_),
67- n_frames(n_frames_) {}
68+ std::vector<float >&& data,
69+ int32_t batch_size,
70+ int32_t n_bins,
71+ int32_t n_frames)
72+ : data_(std::move(data)),
73+ batch_size_(batch_size),
74+ n_bins_(n_bins),
75+ n_frames_(n_frames) {
76+ ET_CHECK_MSG (
77+ data_.index () == 1 &&
78+ std::get<std::vector<float >>(data_).size () ==
79+ static_cast <size_t >(batch_size * n_bins * n_frames),
80+ " data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)" ,
81+ std::get<std::vector<float >>(data_).size (),
82+ batch_size * n_bins * n_frames);
83+ }
6884
6985 // Type checkers
7086 bool is_uint8 () const {
71- return std::holds_alternative<std::vector<uint8_t >>(data );
87+ return std::holds_alternative<std::vector<uint8_t >>(data_ );
7288 }
7389
7490 bool is_float () const {
75- return std::holds_alternative<std::vector<float >>(data );
91+ return std::holds_alternative<std::vector<float >>(data_ );
7692 }
7793
7894 // Data access
7995 const std::vector<uint8_t >& get_uint8_data () const & {
80- return std::get<std::vector<uint8_t >>(data );
96+ return std::get<std::vector<uint8_t >>(data_ );
8197 }
8298
8399 std::vector<uint8_t >& get_uint8_data () & {
84- return std::get<std::vector<uint8_t >>(data );
100+ return std::get<std::vector<uint8_t >>(data_ );
85101 }
86102
87103 const std::vector<float >& get_float_data () const & {
88- return std::get<std::vector<float >>(data );
104+ return std::get<std::vector<float >>(data_ );
89105 }
90106
91107 std::vector<float >& get_float_data () & {
92- return std::get<std::vector<float >>(data );
108+ return std::get<std::vector<float >>(data_ );
93109 }
94110
95111 int32_t get_batch_size () const {
96- return batch_size ;
112+ return batch_size_ ;
97113 }
98114 int32_t get_n_bins () const {
99- return n_bins ;
115+ return n_bins_ ;
100116 }
101117 int32_t get_n_frames () const {
102- return n_frames ;
118+ return n_frames_ ;
103119 }
104120 /* *
105121 * Convert the audio data to a TensorPtr, with optional batch dimension.
106122 * The tensor will have shape (batch_size, n_bins, n_frames) or (1,
107123 * batch_size, n_bins, n_frames) if with_batch is true.
108124 */
109- executorch::runtime::Result<executorch::extension::TensorPtr> toTensor ()
110- const {
111- std::vector<executorch::aten::SizesType> sizes = {
112- get_batch_size (), get_n_bins (), get_n_frames ()};
113- if (is_float ()) {
114- return executorch::extension::from_blob (
115- const_cast <float *>(get_float_data ().data ()),
116- sizes,
117- ::executorch::aten::ScalarType::Float);
118- } else if (is_uint8 ()) {
119- return executorch::extension::from_blob (
120- const_cast <uint8_t *>(get_uint8_data ().data ()),
121- sizes,
122- ::executorch::aten::ScalarType::Byte);
125+ executorch::runtime::Result<executorch::extension::TensorPtr> toTensor (
126+ bool with_batch = false ) {
127+ const {
128+ std::vector<executorch::aten::SizesType> sizes = {
129+ get_batch_size (), get_n_bins (), get_n_frames ()};
130+ if (with_batch) {
131+ sizes.insert (sizes.begin (), 1 );
132+ }
133+ if (is_float ()) {
134+ return executorch::extension::from_blob (
135+ const_cast <float *>(get_float_data ().data ()),
136+ sizes,
137+ ::executorch::aten::ScalarType::Float);
138+ } else if (is_uint8 ()) {
139+ return executorch::extension::from_blob (
140+ const_cast <uint8_t *>(get_uint8_data ().data ()),
141+ sizes,
142+ ::executorch::aten::ScalarType::Byte);
143+ }
144+ ET_LOG (
145+ Error,
146+ " Shouldn't reach here, audio data is not initialized with uint8_t or float vector." );
147+ return ::executorch::runtime::Error::NotSupported;
123148 }
124- ET_LOG (
125- Error, " Audio data is not initialized with uint8_t or float vector." );
126- return ::executorch::runtime::Error::NotSupported;
127- }
128149
129- private:
130- // Members
131- std::variant<std::vector<uint8_t >, std::vector<float >> data ;
132- int32_t batch_size ;
133- int32_t n_bins ;
134- int32_t n_frames ;
135- };
150+ private:
151+ // Members
152+ std::variant<std::vector<uint8_t >, std::vector<float >> data_ ;
153+ int32_t batch_size_ ;
154+ int32_t n_bins_ ;
155+ int32_t n_frames_ ;
156+ };
136157
137158} // namespace llm
138159} // namespace extension
0 commit comments