@@ -161,6 +161,105 @@ inline bool is_contiguous_tensor(
   return true;
 }
 
+// Check if any dimension has a broadcast stride (stride = 0 with size > 1).
+// In PyTorch, stride == 0 means the same data is repeated along that
+// dimension.
+inline bool has_broadcast_strides(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  if (strides_ptr == nullptr) {
+    return false;
+  }
+  for (int64_t i = 0; i < ndim; i++) {
+    if (strides_ptr[i] == 0 && sizes_ptr[i] > 1) {
+      return true;
+    }
+  }
+  return false;
+}
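+
+// Example (illustrative): a size-{1, 3} tensor expanded to {4, 3} typically
+// carries strides {0, 1}, so has_broadcast_strides(2, sizes, strides)
+// returns true for those arrays.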
+
+// Materialize broadcast data by copying with expansion.
+// src_data: pointer to the source data
+// dst_data: pointer to the destination buffer (must be pre-allocated)
+// ndim: number of dimensions
+// sizes_ptr: target sizes
+// strides_ptr: source strides (may contain zeros for broadcast dims)
+// element_size: size of each element in bytes
+// Returns the number of bytes written.
+inline size_t materialize_broadcast(
+    const void* src_data,
+    void* dst_data,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    size_t element_size) {
+  if (ndim == 0) {
+    // Scalar case: copy a single element.
+    std::memcpy(dst_data, src_data, element_size);
+    return element_size;
+  }
+
+  // Calculate the total number of elements in the output
+  int64_t total_elements = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    total_elements *= sizes_ptr[i];
+  }
+
+  // Calculate contiguous output strides
+  std::vector<int64_t> out_strides(ndim);
+  out_strides[ndim - 1] = 1;
+  for (int64_t i = ndim - 2; i >= 0; i--) {
+    out_strides[i] = out_strides[i + 1] * sizes_ptr[i + 1];
+  }
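+  // e.g. sizes {2, 3, 4} yield out_strides {12, 4, 1}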
+
+  // Copy each element, computing the source offset from the input strides
+  const char* src = static_cast<const char*>(src_data);
+  char* dst = static_cast<char*>(dst_data);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; linear_idx++) {
+    // Convert the linear index to a multi-dimensional index and accumulate
+    // the source offset
+    int64_t src_offset = 0;
+    int64_t remaining = linear_idx;
+
+    for (int64_t dim = 0; dim < ndim; dim++) {
+      int64_t coord = remaining / out_strides[dim];
+      remaining = remaining % out_strides[dim];
+      // Use the source stride (which may be 0 for broadcast dimensions)
+      src_offset += coord * strides_ptr[dim];
+    }
+
+    // Copy one element
+    std::memcpy(
+        dst + linear_idx * element_size,
+        src + src_offset * element_size,
+        element_size);
+  }
+
+  return total_elements * element_size;
+}
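+
+// Usage sketch (illustrative only; the buffers and shapes below are made up):
+// broadcasting a float tensor expanded to sizes {2, 3} with source strides
+// {0, 1} into a contiguous destination.
+//
+//   float src[3] = {1.f, 2.f, 3.f};
+//   float dst[6];
+//   int64_t sizes[2] = {2, 3};
+//   int64_t strides[2] = {0, 1};
+//   size_t written =
+//       materialize_broadcast(src, dst, 2, sizes, strides, sizeof(float));
+//   // written == 6 * sizeof(float); dst holds {1, 2, 3, 1, 2, 3}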
+
+// Materialize strided tensor data into a contiguous buffer.
+// Similar to materialize_broadcast, but intended for arbitrary strides (not
+// just broadcast strides).
+// src_data: pointer to the source data (may be offset from the base)
+// dst_data: pointer to the destination buffer (must be pre-allocated)
+// ndim: number of dimensions
+// sizes_ptr: tensor sizes
+// strides_ptr: source strides
+// element_size: size of each element in bytes
+// Returns the number of bytes written.
+inline size_t materialize_strided(
+    const void* src_data,
+    void* dst_data,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    size_t element_size) {
+  // The element-wise gather in materialize_broadcast already handles
+  // arbitrary (including zero) strides, so the general strided case can
+  // simply delegate to it.
+  return materialize_broadcast(
+      src_data, dst_data, ndim, sizes_ptr, strides_ptr, element_size);
+}
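+
+// Usage sketch (illustrative only; the buffers and shapes below are made up):
+// compacting every other column of a 2x4 row-major float buffer, i.e.
+// sizes {2, 2} with strides {4, 2}.
+//
+//   float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+//   float dst[4];
+//   int64_t sizes[2] = {2, 2};
+//   int64_t strides[2] = {4, 2};
+//   materialize_strided(src, dst, 2, sizes, strides, sizeof(float));
+//   // dst holds {0, 2, 4, 6}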
+
 } // namespace aoti
 } // namespace backends
 } // namespace executorch