
Commit c225f96

init
1 parent 3dbc15b commit c225f96

11 files changed, +880 -7 lines changed

backends/aoti/common_shims.cpp

Lines changed: 4 additions & 0 deletions
@@ -164,6 +164,10 @@ int32_t aoti_torch_layout_strided() {
 }
 
 // Dtype constants - these return the PyTorch dtype codes
+int32_t aoti_torch_dtype_float16() {
+  return 5; // PyTorch's float16 dtype code
+}
+
 int32_t aoti_torch_dtype_float32() {
   return 6; // PyTorch's float32 dtype code
 }
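
A minimal usage sketch (not taken from the commit) of how shim-side code can branch on the new constant, using the existing aoti_torch_get_dtype shim; `t` is a hypothetical tensor handle, includes/namespaces are abbreviated, and the Error::Ok comparison assumes AOTITorchError aliases the ExecuTorch Error enum, as the return statements elsewhere in this commit suggest:

#include "backends/aoti/common_shims.h"

// Hedged sketch: true when `t` reports PyTorch dtype code 5 (float16);
// any shim error is collapsed to false for brevity.
inline bool is_half_precision(Tensor* t) {
  int32_t dtype = 0;
  if (aoti_torch_get_dtype(t, &dtype) != Error::Ok) {
    return false;
  }
  return dtype == aoti_torch_dtype_float16();
}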

backends/aoti/common_shims.h

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
 // Utility functions for device and layout information
 int32_t aoti_torch_device_type_cpu();
 int32_t aoti_torch_layout_strided();
+int32_t aoti_torch_dtype_float16();
 int32_t aoti_torch_dtype_float32();
 int32_t aoti_torch_dtype_bfloat16();
 int32_t aoti_torch_dtype_int8();

backends/aoti/utils.h

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
       return executorch::aten::ScalarType::Int;
     case 4: // PyTorch's int64 dtype code
       return executorch::aten::ScalarType::Long;
+    case 5: // PyTorch's float16 (half) dtype code
+      return executorch::aten::ScalarType::Half;
     case 6: // PyTorch's float32 dtype code
       return executorch::aten::ScalarType::Float;
     case 11: // PyTorch's bool dtype code
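
A tiny hedged check of the mapping this switch now covers (namespace qualifiers on dtype_to_scalar_type elided, as in the header itself):

#include "backends/aoti/utils.h"

// Hedged sketch: PyTorch dtype code 5 should now map to ExecuTorch's Half.
inline bool float16_maps_to_half() {
  return dtype_to_scalar_type(5) == executorch::aten::ScalarType::Half;
}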

backends/cuda/cuda_backend.py

Lines changed: 1 addition & 5 deletions
@@ -162,11 +162,7 @@ def preprocess(
             "max_autotune_conv_backends": "TRITON",
         }
 
-        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-            [
-                SDPBackend.MATH  # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
-            ]
-        ), torch.no_grad():
+        with collect_unsupported_fallback_kernels(), torch.no_grad():
             # torch._logging.set_logs(post_grad_graphs=True)
             # Here we should expect 1 so file and 1 weight blob in the same directory.
             paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]

backends/cuda/runtime/shims/memory.cpp

Lines changed: 90 additions & 0 deletions
@@ -582,6 +582,96 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
   return Error::Ok;
 }
 
+AOTITorchError aoti_torch_new_tensor_handle(
+    Tensor* orig_handle,
+    Tensor** new_handle) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      orig_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: orig_handle is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      new_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: new_handle is null");
+
+  // Get metadata from the original tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  int32_t dtype;
+  int32_t device_type;
+  int32_t device_index;
+
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_strides(orig_handle, &strides_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_type(orig_handle, &device_type));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_index(orig_handle, &device_index));
+
+  int64_t ndim = orig_handle->dim();
+
+  // Validate dtype
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
+
+  // Ensure device_index is always 0
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
+      device_index);
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = orig_handle->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Convert sizes and strides to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor that shares the same memory as the original
+  // This is similar to PyTorch's Tensor copy constructor - creates a new
+  // tensor object that shares the same underlying storage
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes, // Same sizes as original
+      data_ptr, // Share the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides, // Same strides as original
+      dtype_to_scalar_type(dtype) // Same dtype as original
+  );
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *new_handle = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
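
The reference-count update at the end of the new function is the subtle part. The following standalone analogue (illustrative only; NOT_OWN is recreated here as an arbitrary sentinel value, not the one defined in memory.cpp) shows the intended semantics: borrowed memory stays at NOT_OWN and is never freed by the shim, while owned memory records one more tensor aliasing it.

#include <cstdint>
#include <unordered_map>

// Illustrative analogue of the map update in aoti_torch_new_tensor_handle;
// the real map lives in memory.cpp and tracks Tensor-owned allocations.
constexpr int64_t NOT_OWN = -1; // assumed sentinel for borrowed memory

void on_new_handle(
    std::unordered_map<void*, int64_t>& memory_to_n_tensor,
    void* data_ptr) {
  // Borrowed memory keeps its NOT_OWN marker; owned memory gains one
  // more referencing tensor, so it is freed only when the count returns to 0.
  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
      ? NOT_OWN
      : memory_to_n_tensor[data_ptr] + 1;
}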

backends/cuda/runtime/shims/memory.h

Lines changed: 25 additions & 0 deletions
@@ -114,6 +114,31 @@ AOTITorchError aoti_torch__reinterpret_tensor(
     int64_t storage_offset,
     Tensor** ret_new_tensor);
 
+/**
+ * Creates a new tensor handle from an existing one.
+ *
+ * This function creates a new tensor object that shares the same underlying
+ * memory as the original tensor. Similar to PyTorch's Tensor copy constructor,
+ * it creates a new handle/reference to the same data without performing a deep
+ * copy.
+ *
+ * The new tensor will:
+ * - Share the same memory/storage as the original tensor
+ * - Have the same shape, strides, and dtype as the original
+ * - Increment the reference count for the underlying memory (if owned)
+ *
+ * @param orig_handle Original tensor to create a new handle from (must not be
+ * null)
+ * @param new_handle Output pointer to store the new tensor handle (must not be
+ * null)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ * - Error::InvalidArgument: null pointers or invalid parameters
+ */
+AOTITorchError aoti_torch_new_tensor_handle(
+    Tensor* orig_handle,
+    Tensor** new_handle);
+
 /**
  * Copies data from source tensor to destination tensor.
  *
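
A hedged usage sketch of the API documented above; `src` is a hypothetical, already-constructed tensor handle, and error handling is left to the caller:

#include "backends/cuda/runtime/shims/memory.h"

// Hedged sketch: create a second handle that aliases `src`'s storage.
// On success, *alias_out shares memory, shape, strides, and dtype with `src`,
// and the memory's reference count is bumped when the memory is owned.
AOTITorchError make_alias(Tensor* src, Tensor** alias_out) {
  return aoti_torch_new_tensor_handle(src, alias_out);
}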

backends/cuda/runtime/utils.h

Lines changed: 4 additions & 1 deletion
@@ -61,6 +61,7 @@ enum class SupportedDTypes : int32_t {
   INT16 = 2, // PyTorch's int16 dtype code
   INT32 = 3, // PyTorch's int32 dtype code
   INT64 = 4, // PyTorch's int64 dtype code
+  FLOAT16 = 5, // PyTorch's float16 dtype code
   FLOAT32 = 6, // PyTorch's float32 dtype code
   BOOL = 11, // PyTorch's bool dtype code
   BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
@@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
     case static_cast<int32_t>(SupportedDTypes::INT16):
     case static_cast<int32_t>(SupportedDTypes::INT32):
     case static_cast<int32_t>(SupportedDTypes::INT64):
+    case static_cast<int32_t>(SupportedDTypes::FLOAT16):
     case static_cast<int32_t>(SupportedDTypes::FLOAT32):
     case static_cast<int32_t>(SupportedDTypes::BOOL):
     case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
@@ -98,12 +100,13 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
   ET_CHECK_OR_RETURN_ERROR(
       is_dtype_supported_in_et_cuda(dtype),
       InvalidArgument,
-      "Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)",
+      "Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float16), %d (float32), %d (bool), %d (bfloat16)",
       dtype,
       static_cast<int32_t>(SupportedDTypes::INT8),
       static_cast<int32_t>(SupportedDTypes::INT16),
       static_cast<int32_t>(SupportedDTypes::INT32),
       static_cast<int32_t>(SupportedDTypes::INT64),
+      static_cast<int32_t>(SupportedDTypes::FLOAT16),
       static_cast<int32_t>(SupportedDTypes::FLOAT32),
       static_cast<int32_t>(SupportedDTypes::BOOL),
       static_cast<int32_t>(SupportedDTypes::BFLOAT16));
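
Finally, a hedged check that the CUDA-side validation helpers now accept half precision (namespaces elided; the Error::Ok comparison assumes AOTITorchError aliases the ExecuTorch Error enum, as elsewhere in these shims):

#include "backends/cuda/runtime/utils.h"

// Hedged sketch: dtype code 5 (float16) now passes both helpers.
inline bool float16_now_supported() {
  return is_dtype_supported_in_et_cuda(
             static_cast<int32_t>(SupportedDTypes::FLOAT16)) &&
      validate_dtype(5) == Error::Ok;
}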
