@@ -161,6 +161,83 @@ inline bool is_contiguous_tensor(
161161 return true ;
162162}
163163
// Check if any dimension has a broadcast stride (stride = 0 with size > 1).
// In PyTorch, stride = 0 means the same element is repeated (broadcast)
// along that dimension, so the data cannot be walked as a plain strided
// view without first materializing it.
//
// ndim: number of dimensions
// sizes_ptr: tensor sizes; may be nullptr (treated as "no broadcast")
// strides_ptr: tensor strides; may be nullptr (treated as "no broadcast")
// Returns true iff some dimension i has strides_ptr[i] == 0 and
// sizes_ptr[i] > 1.
inline bool has_broadcast_strides(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr) {
  // Guard BOTH arrays: the original checked only strides_ptr but then
  // dereferenced sizes_ptr unconditionally in the loop.
  if (sizes_ptr == nullptr || strides_ptr == nullptr) {
    return false;
  }
  for (int64_t i = 0; i < ndim; i++) {
    // stride == 0 on a size-1 dimension is harmless (no repetition);
    // only size > 1 actually aliases the same data.
    if (strides_ptr[i] == 0 && sizes_ptr[i] > 1) {
      return true;
    }
  }
  return false;
}
180+
// Materialize strided tensor data into a contiguous (row-major) buffer.
// Elements are copied one at a time, using the source strides to locate
// each element in src.
//
// src_data: pointer to source data (may be offset from the base allocation)
// dst_data: pointer to destination buffer (must be pre-allocated)
// ndim: number of dimensions
// sizes_ptr: tensor sizes
// strides_ptr: source strides (must not contain broadcast strides)
// element_size: size of each element in bytes
// Returns the number of bytes written to dst_data.
inline size_t materialize_strided(
    const void* src_data,
    void* dst_data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    size_t element_size) {
  // Rank-0 (scalar): exactly one element, no indexing required.
  if (ndim == 0) {
    std::memcpy(dst_data, src_data, element_size);
    return element_size;
  }

  // Number of elements in the materialized output (0 if any dim is 0).
  int64_t numel = 1;
  for (int64_t d = 0; d < ndim; d++) {
    numel *= sizes_ptr[d];
  }

  const char* src_bytes = static_cast<const char*>(src_data);
  char* dst_bytes = static_cast<char*>(dst_data);

  // Walk the output in row-major order while maintaining an "odometer"
  // coordinate and the matching element offset into src. This avoids the
  // per-element div/mod of a linear-index decomposition.
  std::vector<int64_t> coord(ndim, 0);
  int64_t src_elem_offset = 0;
  for (int64_t out_idx = 0; out_idx < numel; out_idx++) {
    std::memcpy(
        dst_bytes + out_idx * element_size,
        src_bytes + src_elem_offset * element_size,
        element_size);

    // Advance the odometer: bump the innermost dimension, carrying into
    // outer dimensions as each one wraps back to zero.
    for (int64_t d = ndim - 1; d >= 0; d--) {
      coord[d]++;
      src_elem_offset += strides_ptr[d];
      if (coord[d] < sizes_ptr[d]) {
        break;
      }
      // This dimension wrapped: undo its accumulated stride contribution
      // and continue carrying into the next-outer dimension.
      coord[d] = 0;
      src_elem_offset -= sizes_ptr[d] * strides_ptr[d];
    }
  }

  return static_cast<size_t>(numel) * element_size;
}
240+
164241} // namespace aoti
165242} // namespace backends
166243} // namespace executorch
0 commit comments