@@ -94,6 +94,11 @@ Result<uint64_t> MultimodalPrefiller::prefill(
9494 // `cache_position` goes from start_pos to start_pos + encoder_output.size(1) - 1 (inclusive).
9595 // e.g. if start_pos = 2 and encoder_output.size(1) = 5,
9696 // cache_position_tensor should be [2, 3, 4, 5, 6].
97+ auto method_meta = ET_UNWRAP (module_->method_meta (kTextModelMethod ));
98+ auto first_input_info = ET_UNWRAP (method_meta.input_tensor_meta (0 ));
99+ auto sizes = first_input_info.sizes ();
100+ auto numel = sizes[0 ];
101+
97102 int64_t seq_len = encoder_output.toTensor ().size (1 );
98103 if (seq_len == 0 ) {
99104 ET_LOG (Error, " The encoder returned an empty output." );
@@ -103,10 +108,13 @@ Result<uint64_t> MultimodalPrefiller::prefill(
103108 for (int64_t i = 0 ; i < seq_len; ++i) {
104109 cache_positions[i] = start_pos + i;
105110 }
106- auto cache_position_tensor = ::executorch::extension::from_blob (
107- cache_positions.data (),
108- {static_cast <int >(seq_len)},
109- executorch::aten::ScalarType::Long);
111+ auto cache_position_tensor = (numel > 1 )
112+ ? ::executorch::extension::from_blob (
113+ cache_positions.data (),
114+ {static_cast <int >(seq_len)},
115+ executorch::aten::ScalarType::Long)
116+ : ::executorch::extension::from_blob (
117+ &start_pos, {1 }, executorch::aten::ScalarType::Long);
110118 auto prefill_result = module_->execute (
111119 kTextModelMethod , {cache_position_tensor, encoder_output});
112120 if (prefill_result.error () != ::executorch::runtime::Error::Ok) {
0 commit comments