Skip to content

Commit 821d085

Browse files
committed
Address comments
1 parent a246998 commit 821d085

File tree

2 files changed

+78
-53
lines changed

2 files changed

+78
-53
lines changed

extension/llm/runner/audio.h

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -39,100 +39,121 @@ struct ET_EXPERIMENTAL RawAudio {
3939
* Mel spectrograms are typically represented as floating point values. For raw
4040
* or quantized audio, uint8_t may be used instead.
4141
*/
42-
class ET_EXPERIMENTAL Audio {
42+
class ET_EXPERIMENTAL Audio final {
4343
public:
4444
// Default constructor
45-
Audio() : batch_size(0), n_bins(0), n_frames(0) {}
45+
Audio() : batch_size_(0), n_bins_(0), n_frames_(0) {}
4646

4747
// Constructor for uint8_t data
4848
Audio(
49-
std::vector<uint8_t>&& data_,
50-
int32_t batch_size_,
51-
int32_t n_bins_,
52-
int32_t n_frames_)
53-
: data(std::move(data_)),
54-
batch_size(batch_size_),
55-
n_bins(n_bins_),
56-
n_frames(n_frames_) {}
49+
std::vector<uint8_t>&& data,
50+
int32_t batch_size,
51+
int32_t n_bins,
52+
int32_t n_frames)
53+
: data_(std::move(data)),
54+
batch_size_(batch_size),
55+
n_bins_(n_bins),
56+
n_frames_(n_frames) {
57+
ET_CHECK_MSG(
58+
data_.index() == 0 &&
59+
std::get<std::vector<uint8_t>>(data_).size() ==
60+
static_cast<size_t>(batch_size * n_bins * n_frames),
61+
"data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)",
62+
std::get<std::vector<uint8_t>>(data_).size(),
63+
batch_size * n_bins * n_frames);
64+
}
5765

5866
// Constructor for float data
5967
Audio(
60-
std::vector<float>&& data_,
61-
int32_t batch_size_,
62-
int32_t n_bins_,
63-
int32_t n_frames_)
64-
: data(std::move(data_)),
65-
batch_size(batch_size_),
66-
n_bins(n_bins_),
67-
n_frames(n_frames_) {}
68+
std::vector<float>&& data,
69+
int32_t batch_size,
70+
int32_t n_bins,
71+
int32_t n_frames)
72+
: data_(std::move(data)),
73+
batch_size_(batch_size),
74+
n_bins_(n_bins),
75+
n_frames_(n_frames) {
76+
ET_CHECK_MSG(
77+
data_.index() == 1 &&
78+
std::get<std::vector<float>>(data_).size() ==
79+
static_cast<size_t>(batch_size * n_bins * n_frames),
80+
"data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)",
81+
std::get<std::vector<float>>(data_).size(),
82+
batch_size * n_bins * n_frames);
83+
}
6884

6985
// Type checkers
7086
bool is_uint8() const {
71-
return std::holds_alternative<std::vector<uint8_t>>(data);
87+
return std::holds_alternative<std::vector<uint8_t>>(data_);
7288
}
7389

7490
bool is_float() const {
75-
return std::holds_alternative<std::vector<float>>(data);
91+
return std::holds_alternative<std::vector<float>>(data_);
7692
}
7793

7894
// Data access
7995
const std::vector<uint8_t>& get_uint8_data() const& {
80-
return std::get<std::vector<uint8_t>>(data);
96+
return std::get<std::vector<uint8_t>>(data_);
8197
}
8298

8399
std::vector<uint8_t>& get_uint8_data() & {
84-
return std::get<std::vector<uint8_t>>(data);
100+
return std::get<std::vector<uint8_t>>(data_);
85101
}
86102

87103
const std::vector<float>& get_float_data() const& {
88-
return std::get<std::vector<float>>(data);
104+
return std::get<std::vector<float>>(data_);
89105
}
90106

91107
std::vector<float>& get_float_data() & {
92-
return std::get<std::vector<float>>(data);
108+
return std::get<std::vector<float>>(data_);
93109
}
94110

95111
int32_t get_batch_size() const {
96-
return batch_size;
112+
return batch_size_;
97113
}
98114
int32_t get_n_bins() const {
99-
return n_bins;
115+
return n_bins_;
100116
}
101117
int32_t get_n_frames() const {
102-
return n_frames;
118+
return n_frames_;
103119
}
104120
/**
105121
* Convert the audio data to a TensorPtr, with optional batch dimension.
106122
* The tensor will have shape (batch_size, n_bins, n_frames) or (1,
107123
* batch_size, n_bins, n_frames) if with_batch is true.
108124
*/
109-
executorch::runtime::Result<executorch::extension::TensorPtr> toTensor()
110-
const {
111-
std::vector<executorch::aten::SizesType> sizes = {
112-
get_batch_size(), get_n_bins(), get_n_frames()};
113-
if (is_float()) {
114-
return executorch::extension::from_blob(
115-
const_cast<float*>(get_float_data().data()),
116-
sizes,
117-
::executorch::aten::ScalarType::Float);
118-
} else if (is_uint8()) {
119-
return executorch::extension::from_blob(
120-
const_cast<uint8_t*>(get_uint8_data().data()),
121-
sizes,
122-
::executorch::aten::ScalarType::Byte);
125+
executorch::runtime::Result<executorch::extension::TensorPtr> toTensor(
126+
bool with_batch = false) {
127+
const {
128+
std::vector<executorch::aten::SizesType> sizes = {
129+
get_batch_size(), get_n_bins(), get_n_frames()};
130+
if (with_batch) {
131+
sizes.insert(sizes.begin(), 1);
132+
}
133+
if (is_float()) {
134+
return executorch::extension::from_blob(
135+
const_cast<float*>(get_float_data().data()),
136+
sizes,
137+
::executorch::aten::ScalarType::Float);
138+
} else if (is_uint8()) {
139+
return executorch::extension::from_blob(
140+
const_cast<uint8_t*>(get_uint8_data().data()),
141+
sizes,
142+
::executorch::aten::ScalarType::Byte);
143+
}
144+
ET_LOG(
145+
Error,
146+
"Shouldn't reach here, audio data is not initialized with uint8_t or float vector.");
147+
return ::executorch::runtime::Error::NotSupported;
123148
}
124-
ET_LOG(
125-
Error, "Audio data is not initialized with uint8_t or float vector.");
126-
return ::executorch::runtime::Error::NotSupported;
127-
}
128149

129-
private:
130-
// Members
131-
std::variant<std::vector<uint8_t>, std::vector<float>> data;
132-
int32_t batch_size;
133-
int32_t n_bins;
134-
int32_t n_frames;
135-
};
150+
private:
151+
// Members
152+
std::variant<std::vector<uint8_t>, std::vector<float>> data_;
153+
int32_t batch_size_;
154+
int32_t n_bins_;
155+
int32_t n_frames_;
156+
};
136157

137158
} // namespace llm
138159
} // namespace extension

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ Result<uint64_t> MultimodalPrefiller::prefill(
9696
// Use Audio::toTensor() for tensor creation
9797
auto audio_tensor =
9898
ET_UNWRAP(audio.toTensor(), "Failed to convert audio to tensor");
99-
99+
ET_LOG(
100+
Info,
101+
"Audio tensor dim: %zu, dtype: %s",
102+
audio_tensor->dim(),
103+
::executorch::runtime::toString(audio_tensor->scalar_type()));
100104
// Run audio encoder
101105
auto audio_encoder_result =
102106
module_->execute(kAudioEncoderMethod, audio_tensor);

0 commit comments

Comments
 (0)