@@ -91,30 +91,38 @@ Result<uint64_t> MultimodalPrefiller::prefill(
9191 }
9292
9393 // 2. Run decoder model for prefill.
94- // `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
95- // e.g. if start_pos = 2 and encoder_output.size(1) = 5,
96- // cache_position_tensor should be [2, 3, 4, 5, 6].
94+
95+ // Get the expected shape of the cache position tensor, which should be the
96+ // second input (index 1) of the text model method.
9797 auto method_meta = ET_UNWRAP (module_->method_meta (kTextModelMethod ));
98- auto first_input_info = ET_UNWRAP (method_meta.input_tensor_meta (0 ));
99- auto sizes = first_input_info .sizes ();
100- auto numel = sizes [0 ];
98+ auto second_input_info = ET_UNWRAP (method_meta.input_tensor_meta (1 ));
99+ auto second_input_sizes = second_input_info .sizes ();
100+ auto numel = second_input_sizes [0 ];
101101
102102 int64_t seq_len = encoder_output.toTensor ().size (1 );
103103 if (seq_len == 0 ) {
104104 ET_LOG (Error, " The encoder returned an empty output." );
105105 return ::executorch::runtime::Error::InvalidState;
106106 }
107- std::vector<int64_t > cache_positions (seq_len);
108- for (int64_t i = 0 ; i < seq_len; ++i) {
109- cache_positions[i] = start_pos + i;
107+
108+ executorch::extension::TensorPtr cache_position_tensor;
109+ if (numel > 1 ) {
110+ // `cache_position` covers the half-open range
111+ // [start_pos, start_pos + encoder_output.size(1)). e.g. if start_pos = 2 and
112+ // encoder_output.size(1) = 5, cache_position_tensor should be [2, 3, 4, 5, 6].
113+ std::vector<int64_t > cache_positions (seq_len);
114+ for (int64_t i = 0 ; i < seq_len; ++i) {
115+ cache_positions[i] = start_pos + i;
116+ }
117+ cache_position_tensor = ::executorch::extension::from_blob (
118+ cache_positions.data (),
119+ {static_cast <int >(seq_len)},
120+ executorch::aten::ScalarType::Long);
121+ } else {
122+ // Cache position is size 1.
123+ cache_position_tensor = ::executorch::extension::from_blob (
124+ &start_pos, {1 }, executorch::aten::ScalarType::Long);
110125 }
111- auto cache_position_tensor = (numel > 1 )
112- ? ::executorch::extension::from_blob (
113- cache_positions.data (),
114- {static_cast <int >(seq_len)},
115- executorch::aten::ScalarType::Long)
116- : ::executorch::extension::from_blob (
117- &start_pos, {1 }, executorch::aten::ScalarType::Long);
118126 auto prefill_result = module_->execute (
119127 kTextModelMethod , {cache_position_tensor, encoder_output});
120128 if (prefill_result.error () != ::executorch::runtime::Error::Ok) {
0 commit comments