2 changes: 1 addition & 1 deletion .github/workflows/cuda.yml
@@ -71,7 +71,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- model: [linear, add, add_mul, resnet18, conv1d]
+ model: [linear, add, add_mul, resnet18, conv1d, sdpa]
with:
timeout: 90
runner: linux.g5.4xlarge.nvidia.gpu
12 changes: 4 additions & 8 deletions backends/aoti/common_shims.cpp
@@ -166,6 +166,10 @@ int32_t aoti_torch_layout_strided() {
}

// Dtype constants - these return the PyTorch dtype codes
int32_t aoti_torch_dtype_float16() {
return 5; // PyTorch's float16 dtype code
}

int32_t aoti_torch_dtype_float32() {
return 6; // PyTorch's float32 dtype code
}
@@ -238,14 +242,6 @@ aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
return Error::Internal;
}

- AOTI_SHIM_EXPORT AOTITorchError
- aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle) {
-   (void)orig_handle;
-   (void)new_handle;
-   throw std::runtime_error("Not implemented");
-   return Error::Internal;
- }

AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
void* data_ptr,
int64_t ndim,
6 changes: 3 additions & 3 deletions backends/aoti/common_shims.h
@@ -60,14 +60,17 @@ AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);

// Utility functions for device and layout information

AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();

// Dtype utility function needed by Metal backend
AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype);
@@ -94,9 +97,6 @@ aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor);

- AOTI_SHIM_EXPORT AOTITorchError
- aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);

AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
void* data_ptr,
int64_t ndim,
2 changes: 2 additions & 0 deletions backends/aoti/utils.h
@@ -43,6 +43,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
return executorch::aten::ScalarType::Int;
case 4: // PyTorch's int64 dtype code
return executorch::aten::ScalarType::Long;
case 5: // PyTorch's float16 (half) dtype code
return executorch::aten::ScalarType::Half;
case 6: // PyTorch's float32 dtype code
return executorch::aten::ScalarType::Float;
case 11: // PyTorch's bool dtype code
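For reference, the integer codes handled by `dtype_to_scalar_type` mirror PyTorch's `c10::ScalarType` values; code 5 is the float16 (Half) entry this PR adds. A minimal Python sketch of the same mapping (the dict name is illustrative, not part of the codebase):

```python
import torch

# PyTorch ScalarType codes mirrored by the AOTI shims.
PYTORCH_DTYPE_CODES = {
    1: torch.int8,       # Char
    2: torch.int16,      # Short
    3: torch.int32,      # Int
    4: torch.int64,      # Long
    5: torch.float16,    # Half (newly supported here)
    6: torch.float32,    # Float
    11: torch.bool,      # Bool
    15: torch.bfloat16,  # BFloat16
}
```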
31 changes: 31 additions & 0 deletions backends/cuda/TARGETS
@@ -11,6 +11,7 @@ runtime.python_library(
"//executorch/...",
],
deps = [
":triton_replacement_pass",
"//caffe2:torch",
"//executorch/backends/aoti/passes:passes",
"//executorch/exir/_serialize:lib",
@@ -32,3 +33,33 @@
"//executorch/backends/aoti:aoti_partitioner",
],
)

runtime.python_library(
name = "triton_kernels",
srcs = [
"triton/kernels/__init__.py",
"triton/kernels/optimized_sdpa.py",
],
visibility = [
"//executorch/backends/cuda/...",
],
deps = [
"//caffe2:torch",
],
)

runtime.python_library(
name = "triton_replacement_pass",
srcs = [
"triton/__init__.py",
"triton/replacement_pass.py",
],
visibility = [
"//executorch/...",
],
deps = [
":triton_kernels",
"//caffe2:torch",
"//executorch/exir/dialects:lib",
],
)
15 changes: 9 additions & 6 deletions backends/cuda/cuda_backend.py
@@ -16,6 +16,10 @@
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
ReplaceViewCopyWithViewPass,
)

from executorch.backends.cuda.triton.replacement_pass import (
ReplaceEdgeOpWithTritonOpPass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import (
Expand All @@ -27,7 +31,7 @@
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.decomposition import conv1d_to_conv2d
from torch.export.passes import move_to_device_pass
- from torch.nn.attention import SDPBackend


cuda_decomposition_table = {
torch.ops.aten.conv1d.default: conv1d_to_conv2d,
@@ -127,6 +131,9 @@ def preprocess( # noqa: C901
# replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)

# Replace aten ops with triton ops
ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module)

cuda_edge_program = cuda_edge_program.run_decompositions(
cuda_decomposition_table
)
@@ -188,11 +195,7 @@ def preprocess( # noqa: C901
}
)

- with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-     [
-         SDPBackend.MATH  # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
-     ]
- ), torch.no_grad():
+ with collect_unsupported_fallback_kernels(), torch.no_grad():
# torch._logging.set_logs(post_grad_graphs=True)
# Here we should expect 1 so file and 1 weight blob in the same directory.
paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
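Condensing the preprocess changes above: the Triton replacement pass runs before `run_decompositions`, so `scaled_dot_product_attention` is swapped for a Triton op while it is still a single node, which is presumably why the `SDPBackend.MATH` context manager around `aot_compile` is no longer needed. A sketch of the resulting transform order:

```python
# Graph transforms in CudaBackend.preprocess, in order (condensed sketch):
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)    # *_copy -> view ops
ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module)  # aten -> triton ops
cuda_edge_program = cuda_edge_program.run_decompositions(
    cuda_decomposition_table  # e.g. conv1d -> conv2d
)
```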
90 changes: 90 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
@@ -582,6 +582,96 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
return Error::Ok;
}

AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle) {
// Validate input parameters
ET_CHECK_OR_RETURN_ERROR(
orig_handle != nullptr,
InvalidArgument,
"aoti_torch_new_tensor_handle failed: orig_handle is null");

ET_CHECK_OR_RETURN_ERROR(
new_handle != nullptr,
InvalidArgument,
"aoti_torch_new_tensor_handle failed: new_handle is null");

// Get metadata from the original tensor
int64_t* sizes_ptr;
int64_t* strides_ptr;
int32_t dtype;
int32_t device_type;
int32_t device_index;

ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_get_strides(orig_handle, &strides_ptr));
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_get_device_type(orig_handle, &device_type));
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_get_device_index(orig_handle, &device_index));

int64_t ndim = orig_handle->dim();

// Validate dtype
ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));

// Ensure device_index is always 0
ET_CHECK_OR_RETURN_ERROR(
device_index == 0,
InvalidArgument,
"device_index must be 0, got: %d",
device_index);

// Get the original data pointer from the source tensor
void* data_ptr = orig_handle->mutable_data_ptr();
ET_CHECK_OR_RETURN_ERROR(
data_ptr != nullptr,
InvalidArgument,
"Source tensor has null data pointer");

// Check if the given memory is in the map
auto memory_it = memory_to_n_tensor.find(data_ptr);
ET_CHECK_OR_RETURN_ERROR(
memory_it != memory_to_n_tensor.end(),
InvalidArgument,
"Memory address %p is not being tracked by reference counting system",
data_ptr);

// Convert sizes and strides to vectors
std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
std::vector<StridesType> strides =
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// Create new tensor that shares the same memory as the original
// This is similar to PyTorch's Tensor copy constructor - creates a new
// tensor object that shares the same underlying storage
std::shared_ptr<Tensor> tensor = make_tensor(
sizes, // Same sizes as original
data_ptr, // Share the same memory from source tensor
{}, // dim_order (empty, will be auto-generated)
strides, // Same strides as original
dtype_to_scalar_type(dtype) // Same dtype as original
);

ET_CHECK_OR_RETURN_ERROR(
tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");

// Store the tensor so it doesn't get destroyed
tensors.insert(tensor);

*new_handle = tensor.get();

// Increment the reference count for this memory address only if it is owned
// by tensor
memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
? NOT_OWN
: memory_to_n_tensor[data_ptr] + 1;

return Error::Ok;
}

AOTITorchError aoti_torch__reinterpret_tensor(
Tensor* self,
int64_t ndim,
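The reference-counting update at the end of `aoti_torch_new_tensor_handle` is subtle: memory tagged `NOT_OWN` keeps its sentinel so the runtime never frees storage it merely borrows, while owned memory has its tensor count incremented. A minimal Python model of just that rule (names and the sentinel value are illustrative, not the real definitions):

```python
NOT_OWN = -1  # assumed sentinel; the real value is defined in the CUDA runtime

def bump_refcount(memory_to_n_tensor: dict, data_ptr: int) -> None:
    # Mirrors the ternary above: NOT_OWN stays NOT_OWN, owned counts go up.
    count = memory_to_n_tensor[data_ptr]
    memory_to_n_tensor[data_ptr] = count if count == NOT_OWN else count + 1
```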
25 changes: 25 additions & 0 deletions backends/cuda/runtime/shims/memory.h
@@ -115,6 +115,31 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
int64_t storage_offset,
Tensor** ret_new_tensor);

/**
* Creates a new tensor handle from an existing one.
*
* This function creates a new tensor object that shares the same underlying
* memory as the original tensor. Similar to PyTorch's Tensor copy constructor,
* it creates a new handle/reference to the same data without performing a deep
* copy.
*
* The new tensor will:
* - Share the same memory/storage as the original tensor
* - Have the same shape, strides, and dtype as the original
* - Increment the reference count for the underlying memory (if owned)
*
* @param orig_handle Original tensor to create a new handle from (must not be
* null)
* @param new_handle Output pointer to store the new tensor handle (must not be
* null)
*
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers or invalid parameters
*/
AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle);

/**
* Copies data from source tensor to destination tensor.
*
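For intuition, a rough PyTorch-level analogy of what the documented guarantees of the new handle amount to (an analogy only; AOTInductor-generated code calls the shim through the C ABI):

```python
import torch

a = torch.randn(2, 3)
b = a.as_strided(a.size(), a.stride())  # new tensor object, same storage

assert b.data_ptr() == a.data_ptr()  # shares memory with the original
assert b.shape == a.shape            # same sizes
assert b.stride() == a.stride()      # same strides
assert b.dtype == a.dtype            # same dtype
```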
5 changes: 4 additions & 1 deletion backends/cuda/runtime/utils.h
@@ -61,6 +61,7 @@ enum class SupportedDTypes : int32_t {
INT16 = 2, // PyTorch's int16 dtype code
INT32 = 3, // PyTorch's int32 dtype code
INT64 = 4, // PyTorch's int64 dtype code
FLOAT16 = 5, // PyTorch's float16 dtype code
FLOAT32 = 6, // PyTorch's float32 dtype code
BOOL = 11, // PyTorch's bool dtype code
BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
@@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
case static_cast<int32_t>(SupportedDTypes::INT16):
case static_cast<int32_t>(SupportedDTypes::INT32):
case static_cast<int32_t>(SupportedDTypes::INT64):
case static_cast<int32_t>(SupportedDTypes::FLOAT16):
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
case static_cast<int32_t>(SupportedDTypes::BOOL):
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
@@ -98,12 +100,13 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
ET_CHECK_OR_RETURN_ERROR(
is_dtype_supported_in_et_cuda(dtype),
InvalidArgument,
"Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)",
"Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float16), %d (float32), %d (bool), %d (bfloat16)",
dtype,
static_cast<int32_t>(SupportedDTypes::INT8),
static_cast<int32_t>(SupportedDTypes::INT16),
static_cast<int32_t>(SupportedDTypes::INT32),
static_cast<int32_t>(SupportedDTypes::INT64),
static_cast<int32_t>(SupportedDTypes::FLOAT16),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDTypes::BOOL),
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
43 changes: 43 additions & 0 deletions backends/cuda/tests/test_cuda_export.py
@@ -270,3 +270,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Test export
edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")

def test_sdpa_single_kernel(self):
"""
Test CUDA export for model containing single SDPA kernel.

SDPA: Scaled Dot Product Attention
"""

class SDPAModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, query, key, value):
out = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
)
return out

module = SDPAModule()
module.eval()

# Create input tensors (batch, num_heads, seq_len, head_dim)
batch_size = 2
num_heads = 8
seq_len = 128
head_dim = 64

query = torch.randn(batch_size, num_heads, seq_len, head_dim)
key = torch.randn(batch_size, num_heads, seq_len, head_dim)
value = torch.randn(batch_size, num_heads, seq_len, head_dim)
inputs = (query, key, value)

# Test export
edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
self.assertIsNotNone(
edge_program_manager,
"SDPA single kernel operation export failed",
)
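For context, with `attn_mask=None`, `dropout_p=0.0`, and `is_causal=False` as in the test, SDPA reduces to plain softmax attention; a reference sketch of the semantics any replacement kernel must preserve:

```python
import math

import torch

def sdpa_reference(q, k, v):
    # softmax(Q @ K^T / sqrt(head_dim)) @ V
    scale = 1.0 / math.sqrt(q.size(-1))
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v

# Shapes as in the test: (batch, num_heads, seq_len, head_dim)
q = k = v = torch.randn(2, 8, 128, 64)
torch.testing.assert_close(
    sdpa_reference(q, k, v),
    torch.nn.functional.scaled_dot_product_attention(q, k, v),
    rtol=1e-4,
    atol=1e-4,
)
```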
17 changes: 17 additions & 0 deletions backends/cuda/triton/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Import all kernels to ensure @triton_op decorators are executed
# and ops are registered to torch.ops.triton namespace
from executorch.backends.cuda.triton import kernels # noqa: F401

from executorch.backends.cuda.triton.replacement_pass import (
ReplaceEdgeOpWithTritonOpPass,
)

__all__ = [
"ReplaceEdgeOpWithTritonOpPass",
]
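A short usage sketch of the package surface defined in the `__init__.py` above (assuming only what it exports):

```python
# Importing the package executes the kernels module, so the @triton_op
# registrations land under torch.ops.triton as a side effect.
from executorch.backends.cuda.triton import ReplaceEdgeOpWithTritonOpPass

replacement = ReplaceEdgeOpWithTritonOpPass()
# replacement(graph_module) rewrites matching aten nodes in place, as done
# in CudaBackend.preprocess.
```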
11 changes: 11 additions & 0 deletions backends/cuda/triton/kernels/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.cuda.triton.kernels.sdpa import sdpa

__all__ = [
"sdpa",
]