2222#include < memory>
2323#include < sstream>
2424
25- namespace torch {
26- namespace executor {
25+ using executorch::aten::ScalarType;
26+ using executorch::aten::SizesType;
27+ using executorch::aten::Tensor;
28+ using executorch::extension::from_blob;
29+ using executorch::extension::Module;
30+ using executorch::extension::TensorPtr;
31+ using executorch::extension::llm::BPETokenizer;
32+ using executorch::extension::llm::Sampler;
33+ using executorch::extension::llm::time_in_ms;
34+ using executorch::runtime::Error;
35+ using executorch::runtime::EValue;
36+ using executorch::runtime::MethodMeta;
37+ using executorch::runtime::Result;
38+ using executorch::runtime::TensorInfo;
39+
40+ // TODO: Remove this usage of an internal-only function.
41+ using executorch::runtime::internal::set_tensor_data;
42+
43+ namespace example {
2744
2845namespace {
29- using namespace executorch ::extension;
3046static constexpr auto kTopp = 0 .9f ;
3147void printReport (const Runner::Stats& stats);
3248std::string statsToJsonString (const Runner::Stats& stats);
@@ -57,7 +73,7 @@ Error Runner::load() {
5773 if (is_loaded ()) {
5874 return Error::Ok;
5975 }
60- stats_.model_load_start_ms = util:: time_in_ms ();
76+ stats_.model_load_start_ms = time_in_ms ();
6177 ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (" forward" ));
6278
6379 // Read out metadata from the model
@@ -97,7 +113,7 @@ Error Runner::load() {
97113 temperature_,
98114 kTopp ,
99115 static_cast <unsigned long long >(std::time (nullptr )));
100- stats_.model_load_end_ms = util:: time_in_ms ();
116+ stats_.model_load_end_ms = time_in_ms ();
101117
102118 return Error::Ok;
103119}
@@ -125,7 +141,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
125141}
126142
127143template <typename T>
128- int32_t Runner::logitsToToken (const exec_aten:: Tensor& logits_tensor) {
144+ int32_t Runner::logitsToToken (const Tensor& logits_tensor) {
129145 T* logits = logits_tensor.mutable_data_ptr <T>();
130146
131147 // Since the logits are for all tokens, get the last token probabilities
@@ -135,7 +151,7 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
135151
136152// Given an input token. Set up the inputs for the model and execute a single
137153// step. Returning the logits tensor.
138- Result<exec_aten:: Tensor> Runner::run_model_step (
154+ Result<Tensor> Runner::run_model_step (
139155 int64_t input_token,
140156 TensorPtr& token,
141157 TensorPtr& start_pos,
@@ -167,7 +183,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
167183 char * new_inp_addr = io_mem_mgr_.update_k_caches_read (j, el_size);
168184 // inputs
169185 ET_CHECK_MSG (
170- internal:: set_tensor_data (
186+ set_tensor_data (
171187 *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes ()) == Error::Ok,
172188 " Failed to set input tensor when updating k_cache" );
173189 }
@@ -177,13 +193,13 @@ Result<exec_aten::Tensor> Runner::run_model_step(
177193 char * new_inp_addr = io_mem_mgr_.update_v_caches_read (v_idx, v_offset);
178194
179195 ET_CHECK_MSG (
180- internal:: set_tensor_data (
196+ set_tensor_data (
181197 *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes ()) == Error::Ok,
182198 " Failed to set input tensor when updating v_cache" );
183199 // outputs
184200 char * new_out_addr = io_mem_mgr_.update_v_caches_write (v_idx, v_offset);
185201 ET_CHECK_MSG (
186- internal:: set_tensor_data (
202+ set_tensor_data (
187203 *kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes ()) == Error::Ok,
188204 " Failed to set output tensor when updating v_cache" );
189205 ET_CHECK_MSG (
@@ -210,7 +226,7 @@ Error Runner::generate(
210226
211227 // First token time only measures the time it takes to encode the prompt and
212228 // return a response token.
213- stats_.inference_start_ms = util:: time_in_ms ();
229+ stats_.inference_start_ms = time_in_ms ();
214230 shouldStop_ = false ;
215231
216232 // Set the sequence length to the max seq length if not provided
@@ -235,21 +251,21 @@ Error Runner::generate(
235251 " Sequence length exceeded - please increase the seq_len value passed to generate()" );
236252
237253 int32_t pos = 0 , prev_token, cur_token = prompt_tokens[0 ];
238- std::vector<exec_aten:: SizesType> token_shape = {1 , 1 };
254+ std::vector<SizesType> token_shape = {1 , 1 };
239255
240256 io_mem_mgr_.get_input_token_ptr ()[0 ] = 0 ;
241- std::vector<exec_aten:: SizesType> start_pos_shape = {1 , 1 };
257+ std::vector<SizesType> start_pos_shape = {1 , 1 };
242258
243259 float * atten_mask_ptr =
244260 reinterpret_cast <float *>(io_mem_mgr_.get_atten_mask_ptr ());
245261 std::fill (atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255 );
246262 atten_mask_ptr[max_seq_len_ - 1 ] = 0 ;
247263
248- std::vector<exec_aten:: SizesType> atten_mask_shape = {1 , max_seq_len_};
264+ std::vector<SizesType> atten_mask_shape = {1 , max_seq_len_};
249265
250- std::vector<exec_aten:: SizesType> logits_data_shape = {1 , vocab_size_};
266+ std::vector<SizesType> logits_data_shape = {1 , vocab_size_};
251267
252- std::vector<exec_aten:: SizesType> hidden_states_data_shape = {1 , 1 , dim_};
268+ std::vector<SizesType> hidden_states_data_shape = {1 , 1 , dim_};
253269
254270 // initialize tensor wrappers
255271 auto token = from_blob (
@@ -274,7 +290,7 @@ Error Runner::generate(
274290 method_meta->input_tensor_meta (input_index);
275291
276292 auto tensor_shape = tensor_meta->sizes ();
277- std::vector<exec_aten:: SizesType> sizes (
293+ std::vector<SizesType> sizes (
278294 tensor_shape.data (), tensor_shape.data () + tensor_shape.size ());
279295 kv_tensors.emplace_back (from_blob (
280296 io_mem_mgr_.get_k_caches_read_ptr (i),
@@ -284,7 +300,7 @@ Error Runner::generate(
284300 // outpus
285301 Result<TensorInfo> out_tensor_meta = method_meta->output_tensor_meta (i + 1 );
286302 tensor_shape = out_tensor_meta->sizes ();
287- sizes = std::vector<exec_aten:: SizesType>{
303+ sizes = std::vector<SizesType>{
288304 tensor_shape.data (), tensor_shape.data () + tensor_shape.size ()};
289305 kv_outputs.emplace_back (from_blob (
290306 io_mem_mgr_.get_k_caches_write_ptr (i),
@@ -303,7 +319,7 @@ Error Runner::generate(
303319 Result<TensorInfo> tensor_meta =
304320 method_meta->input_tensor_meta (input_index);
305321 auto tensor_shape = tensor_meta->sizes ();
306- std::vector<exec_aten:: SizesType> sizes (
322+ std::vector<SizesType> sizes (
307323 tensor_shape.data (), tensor_shape.data () + tensor_shape.size ());
308324
309325 kv_tensors.emplace_back (from_blob (
@@ -315,7 +331,7 @@ Error Runner::generate(
315331 Result<TensorInfo> out_tensor_meta =
316332 method_meta->output_tensor_meta (output_index);
317333 tensor_shape = out_tensor_meta->sizes ();
318- sizes = std::vector<exec_aten:: SizesType>{
334+ sizes = std::vector<SizesType>{
319335 tensor_shape.data (), tensor_shape.data () + tensor_shape.size ()};
320336
321337 kv_outputs.push_back (from_blob (
@@ -342,19 +358,18 @@ Error Runner::generate(
342358 auto logits_res = run_model_step (
343359 cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs);
344360 if (pos == num_prompt_tokens) {
345- stats_.first_token_ms = util:: time_in_ms ();
361+ stats_.first_token_ms = time_in_ms ();
346362 } else if (pos == num_prompt_tokens - 1 ) {
347- stats_.prompt_eval_end_ms = util:: time_in_ms ();
363+ stats_.prompt_eval_end_ms = time_in_ms ();
348364 }
349365
350366 ET_CHECK_OK_OR_RETURN_ERROR (logits_res.error ());
351- exec_aten:: Tensor& logits_tensor = logits_res.get ();
367+ Tensor& logits_tensor = logits_res.get ();
352368 prev_token = cur_token;
353- long sample_start_time_ms = util:: time_in_ms ();
369+ long sample_start_time_ms = time_in_ms ();
354370
355371 cur_token = logitsToToken<float >(logits_tensor);
356- stats_.aggregate_sampling_time_ms +=
357- util::time_in_ms () - sample_start_time_ms;
372+ stats_.aggregate_sampling_time_ms += time_in_ms () - sample_start_time_ms;
358373
359374 // advance the state machine
360375 if (pos < num_prompt_tokens - 1 ) {
@@ -381,7 +396,7 @@ Error Runner::generate(
381396 break ;
382397 }
383398 }
384- stats_.inference_end_ms = util:: time_in_ms ();
399+ stats_.inference_end_ms = time_in_ms ();
385400
386401 if (pos == seq_len) {
387402 ET_LOG (Info, " Sequence length (%i tokens) reached!" , seq_len);
@@ -650,5 +665,4 @@ template bool Runner::getMetadataHelper<bool>(
650665 std::string method_name,
651666 bool default_val);
652667
653- } // namespace executor
654- } // namespace torch
668+ } // namespace example
0 commit comments