@@ -327,29 +327,84 @@ inline TensorPtr make_tensor_ptr(
327327 * Creates a TensorPtr to manage a new Tensor with the same properties
328328 * as the given Tensor, sharing the same data without owning it.
329329 *
330- * @param tensor The Tensor whose properties are used to create a new TensorPtr.
331- * @return A new TensorPtr managing a Tensor with the same properties as the
332- * original.
330+ * If an override is provided (non-empty), it is passed as-is. If an override is
331+ * empty, the corresponding metadata is reused from the source tensor when it
332+ * fits; otherwise it is left empty for the core factory to derive a valid
333+ * configuration. If `dim_order` is empty but `strides` is provided, `dim_order`
334+ * is left empty so the core may infer it from the provided strides.
335+ *
336+ * @param tensor The source tensor to alias.
337+ * @param sizes Optional sizes override.
338+ * @param dim_order Optional dimension order override.
339+ * @param strides Optional strides override.
340+ * @param deleter A custom deleter function for managing the lifetime of the
341+ * original Tensor.
342+ * @return A TensorPtr aliasing the same storage with requested metadata.
333343 */
334- inline TensorPtr make_tensor_ptr (const executorch::aten::Tensor& tensor) {
344+ inline TensorPtr make_tensor_ptr (
345+ const executorch::aten::Tensor& tensor,
346+ std::vector<executorch::aten::SizesType> sizes = {},
347+ std::vector<executorch::aten::DimOrderType> dim_order = {},
348+ std::vector<executorch::aten::StridesType> strides = {},
349+ std::function<void (void *)> deleter = nullptr) {
350+ if (sizes.empty ()) {
351+ sizes.assign (tensor.sizes ().begin (), tensor.sizes ().end ());
352+ }
353+ const auto same_rank = sizes.size () == static_cast <size_t >(tensor.dim ());
354+ const auto same_shape = same_rank &&
355+ std::equal (sizes.begin (), sizes.end (), tensor.sizes ().begin ());
356+ const auto element_count =
357+ executorch::aten::compute_numel (sizes.data (), sizes.size ());
358+ const auto parent_element_count = tensor.numel ();
359+ ET_CHECK_MSG (
360+ element_count <= parent_element_count,
361+ " Requested view has %zd elements, but source tensor only has %zd." ,
362+ static_cast <ssize_t >(element_count),
363+ static_cast <ssize_t >(parent_element_count));
364+ #ifndef USE_ATEN_LIB
365+ if (dim_order.empty () && strides.empty () && same_rank) {
366+ dim_order.assign (tensor.dim_order ().begin (), tensor.dim_order ().end ());
367+ }
368+ #endif // USE_ATEN_LIB
369+ if (strides.empty () && dim_order.empty () && same_shape) {
370+ strides.assign (tensor.strides ().begin (), tensor.strides ().end ());
371+ }
335372 return make_tensor_ptr (
336373 std::vector<executorch::aten::SizesType>(
337374 tensor.sizes ().begin (), tensor.sizes ().end ()),
338375 tensor.mutable_data_ptr (),
339- #ifndef USE_ATEN_LIB
340- std::vector<executorch::aten::DimOrderType>(
341- tensor.dim_order ().begin (), tensor.dim_order ().end ()),
342- std::vector<executorch::aten::StridesType>(
343- tensor.strides ().begin (), tensor.strides ().end ()),
376+ std::move (dim_order),
377+ std::move (strides),
344378 tensor.scalar_type (),
345- tensor.shape_dynamism ()
379+ #ifndef USE_ATEN_LIB
380+ tensor.shape_dynamism (),
346381#else // USE_ATEN_LIB
347- {},
348- std::vector<executorch::aten::StridesType>(
349- tensor.strides ().begin (), tensor.strides ().end ()),
350- tensor.scalar_type ()
382+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
351383#endif // USE_ATEN_LIB
352- );
384+ std::move (deleter));
385+ }
386+
387+ /* *
388+ * Convenience overload identical to make_tensor_ptr(*tensor_ptr, ...).
389+ * Keeps the original TensorPtr alive until the returned TensorPtr is destroyed.
390+ *
391+ * @param tensor_ptr The source tensor pointer to alias.
392+ * @param sizes Optional sizes override.
393+ * @param dim_order Optional dimension order override.
394+ * @param strides Optional strides override.
395+ * @return A TensorPtr aliasing the same storage with requested metadata.
396+ */
397+ inline TensorPtr make_tensor_ptr (
398+ const TensorPtr& tensor_ptr,
399+ std::vector<executorch::aten::SizesType> sizes = {},
400+ std::vector<executorch::aten::DimOrderType> dim_order = {},
401+ std::vector<executorch::aten::StridesType> strides = {}) {
402+ return make_tensor_ptr (
403+ *tensor_ptr,
404+ std::move (sizes),
405+ std::move (dim_order),
406+ std::move (strides),
407+ [tensor_ptr](void *) {});
353408}
354409
355410/* *
0 commit comments