@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
+using executorch::backends::aoti::aoti_torch_get_sizes;
+using executorch::backends::aoti::aoti_torch_get_strides;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
 constexpr int32_t NOT_OWN = -1;
 std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
+namespace {
+
+// Calculate linear offset from strides and indices
+int64_t calculate_linear_offset(
+    const int64_t* indices,
+    const int64_t* strides,
+    int64_t ndim) {
+  int64_t offset = 0;
+  for (int64_t i = 0; i < ndim; ++i) {
+    offset += indices[i] * strides[i];
+  }
+  return offset;
+}
+
+// Convert linear index to multi-dimensional indices based on sizes
+void linear_to_indices(
+    int64_t linear_idx,
+    const int64_t* sizes,
+    int64_t ndim,
+    int64_t* indices) {
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    indices[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+}
+
+// Generic pointwise copy function that handles arbitrary strides
+template <typename T>
+AOTITorchError pointwise_copy_generic(
+    T* dst_data,
+    const T* src_data,
+    const int64_t* dst_sizes,
+    const int64_t* dst_strides,
+    const int64_t* src_sizes,
+    const int64_t* src_strides,
+    int64_t dst_ndim,
+    int64_t src_ndim,
+    int64_t total_elements) {
+  std::vector<int64_t> dst_indices(dst_ndim);
+  std::vector<int64_t> src_indices(src_ndim);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
+    // Convert linear index to multi-dimensional indices for both tensors
+    linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
+    linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());
+
+    // Calculate offsets for both source and destination
+    int64_t src_offset =
+        calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
+    int64_t dst_offset =
+        calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);
+
+    // Copy element
+    dst_data[dst_offset] = src_data[src_offset];
+  }
+
+  return Error::Ok;
+}
+
+} // anonymous namespace
+
 extern "C" {
 
 AOTITorchError aoti_torch_create_tensor_from_blob_v2(
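
The helpers in this hunk split a flat element index into per-dimension indices (linear_to_indices) and then fold those indices back into a strided storage offset (calculate_linear_offset). Below is a minimal standalone sketch of the same index math, assuming a hypothetical 2x3 tensor stored with transposed (column-major) strides purely for illustration; strided_offset is an illustrative name, not part of the diff.

#include <cstdint>
#include <cstdio>

// Same math as linear_to_indices + calculate_linear_offset above,
// fused into one pass from flat index to storage offset.
int64_t strided_offset(int64_t linear_idx, const int64_t* sizes,
                       const int64_t* strides, int64_t ndim) {
  int64_t offset = 0;
  for (int64_t i = ndim - 1; i >= 0; --i) {
    offset += (linear_idx % sizes[i]) * strides[i];
    linear_idx /= sizes[i];
  }
  return offset;
}

int main() {
  // Hypothetical 2x3 tensor whose storage is laid out transposed,
  // so logical element (i, j) lives at storage[i * 1 + j * 2].
  const int64_t sizes[2] = {2, 3};
  const int64_t strides[2] = {1, 2};
  const float storage[6] = {0, 1, 2, 3, 4, 5};
  for (int64_t k = 0; k < 6; ++k) {
    int64_t off = strided_offset(k, sizes, strides, 2);
    printf("flat %lld -> storage[%lld] = %g\n",
           (long long)k, (long long)off, storage[off]);
  }
  return 0;
}
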
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
   }
 
   int64_t nbytes = numel * element_size;
 
-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
     if (result != 0) {
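
For context, here is a minimal sketch of the two allocation paths this hunk dispatches between: CUDA unified memory via cudaMallocManaged, and 16-byte-aligned host memory via posix_memalign. The allocate_buffer helper is hypothetical, and error handling is reduced to a null return.

#include <stdlib.h>
#include <cuda_runtime.h>

// Returns nullptr on failure. The caller releases the buffer with
// cudaFree() on the CUDA path or free() on the CPU path.
void* allocate_buffer(bool on_cuda, size_t nbytes) {
  void* ptr = nullptr;
  if (on_cuda) {
    // Managed (unified) memory is addressable from both host and device.
    if (cudaMallocManaged(&ptr, nbytes) != cudaSuccess) {
      return nullptr;
    }
  } else {
    // 16-byte-aligned host memory, mirroring the alignment chosen above.
    if (posix_memalign(&ptr, 16, nbytes) != 0) {
      return nullptr;
    }
  }
  return ptr;
}
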
@@ -312,6 +376,207 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
+  (void)non_blocking;
+
+  // Check for null pointers first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (src == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  AOTITorchError self_dtype_error = validate_dtype(self_dtype);
+  if (self_dtype_error != Error::Ok) {
+    return self_dtype_error;
+  }
+
+  AOTITorchError src_dtype_error = validate_dtype(src_dtype);
+  if (src_dtype_error != Error::Ok) {
+    return src_dtype_error;
+  }
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  if (self_dtype != src_dtype) {
+    ET_LOG(
+        Error,
+        "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
+        self_dtype,
+        src_dtype);
+    return Error::InvalidArgument;
+  }
+
+  // Check total number of elements compatibility (PyTorch copy_ behavior)
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  if (self_numel != src_numel) {
+    ET_LOG(
+        Error,
+        "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+        self_numel,
+        src_numel);
+    return Error::InvalidArgument;
+  }
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  int64_t* self_sizes;
+  int64_t* src_sizes;
+  aoti_torch_get_sizes(self, &self_sizes);
+  aoti_torch_get_sizes(src, &src_sizes);
+
+  // Determine device locations
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  // Check if tensors have the same schema (sizes, strides, dtype) for fast path
+  bool same_schema = true;
+  for (int i = 0; i < self->dim(); i++) {
+    if (self_strides[i] != src_strides[i]) {
+      same_schema = false;
+      break;
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  int64_t total_elements = self->numel();
+
+  if (same_schema) {
+    // Fast path: Direct memory copy since layouts match exactly
+    if (srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToDevice));
+    } else if (srcIsDevice && !dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToHost));
+    } else if (!srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    } else {
+      std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+    }
+  } else {
+    // Fallback path: Pointwise copy with stride-aware indexing
+    // This handles arbitrary tensor layouts and strides
+
+    size_t element_size = dtype_to_element_size(self_dtype);
+    if (element_size == 0) {
+      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
+      return Error::InvalidArgument;
+    }
+
+    // Allocate temporary host memory for GPU tensors
+    float* src_host_data = nullptr;
+    float* dst_host_data = nullptr;
+    bool need_free_src = false;
+    bool need_free_dst = false;
+
+    if (srcIsDevice) {
+      src_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (src_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for src_host_data");
+        return Error::MemoryAllocationFailed;
+      }
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
+      need_free_src = true;
+    } else {
+      src_host_data = static_cast<float*>(src->data_ptr());
+    }
+
+    if (dstIsDevice) {
+      dst_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (dst_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for dst_host_data");
+        if (need_free_src) {
+          free(src_host_data);
+        }
+        return Error::MemoryAllocationFailed;
+      }
+      need_free_dst = true;
+    } else {
+      dst_host_data = static_cast<float*>(self->mutable_data_ptr());
+    }
+
+    // Perform pointwise copy with stride calculation
+    AOTITorchError copy_err = pointwise_copy_generic(
+        dst_host_data,
+        src_host_data,
+        self_sizes,
+        self_strides,
+        src_sizes,
+        src_strides,
+        self->dim(),
+        src->dim(),
+        total_elements);
+
+    if (copy_err != Error::Ok) {
+      // Clean up temporary buffers before returning
+      if (need_free_src) {
+        free(src_host_data);
+      }
+      if (need_free_dst) {
+        free(dst_host_data);
+      }
+      return copy_err;
+    }
+
+    // Copy result back to device if needed
+    if (dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          dst_host_data,
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    }
+
+    // Clean up temporary buffers
+    if (need_free_src) {
+      free(src_host_data);
+    }
+    if (need_free_dst) {
+      free(dst_host_data);
+    }
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
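
The copy path in the last hunk classifies each data pointer with cudaPointerGetAttributes before choosing a cudaMemcpy direction. A standalone sketch of that classification follows, assuming only the CUDA runtime API; is_device_pointer and the buffer names are illustrative, not part of the diff.

#include <cstdio>
#include <cuda_runtime.h>

// True only for device-resident memory; managed, registered, and plain
// host pointers report other cudaMemoryType values (or an error on
// older CUDA versions), so they fall through to false.
bool is_device_pointer(const void* ptr) {
  cudaPointerAttributes attr{};
  if (cudaPointerGetAttributes(&attr, ptr) != cudaSuccess) {
    return false;
  }
  return attr.type == cudaMemoryTypeDevice;
}

int main() {
  void* device_buf = nullptr;
  cudaMalloc(&device_buf, 64);
  float host_buf[16] = {};

  printf("device_buf is device memory: %d\n", is_device_pointer(device_buf));
  printf("host_buf is device memory: %d\n", is_device_pointer(host_buf));

  cudaFree(device_buf);
  return 0;
}
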