Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# CUDA-specific AOTI functionality
set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
set(_aoti_cuda_sources
runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
runtime/shims/cuda_guard.cpp
)
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ runtime.cxx_library(
name = "runtime_shims",
srcs = [
"guard.cpp",
"shims/cuda_guard.cpp",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/cuda_guard.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
Expand Down
109 changes: 109 additions & 0 deletions backends/cuda/runtime/shims/cuda_guard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>

namespace executorch {
namespace backends {
namespace cuda {

extern "C" {

AOTITorchError aoti_torch_create_cuda_guard(
int32_t device_index,
CUDAGuardHandle* ret_guard) {
ET_CHECK_OR_RETURN_ERROR(
ret_guard != nullptr,
InvalidArgument,
"aoti_torch_create_cuda_guard failed: ret_guard is null");

auto result = CUDAGuard::create(device_index);
if (!result.ok()) {
return result.error();
}
*ret_guard = new CUDAGuard(std::move(result.get()));
return Error::Ok;
}

AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
ET_CHECK_OR_RETURN_ERROR(
guard != nullptr,
InvalidArgument,
"aoti_torch_delete_cuda_guard failed: guard is null");

delete guard;
return Error::Ok;
}

AOTITorchError aoti_torch_cuda_guard_set_index(
CUDAGuardHandle guard,
int32_t device_index) {
ET_CHECK_OR_RETURN_ERROR(
guard != nullptr,
InvalidArgument,
"aoti_torch_cuda_guard_set_index failed: guard is null");

ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index));
return Error::Ok;
}

AOTITorchError aoti_torch_create_cuda_stream_guard(
void* stream,
int32_t device_index,
CUDAStreamGuardHandle* ret_guard) {
ET_CHECK_OR_RETURN_ERROR(
ret_guard != nullptr,
InvalidArgument,
"aoti_torch_create_cuda_stream_guard failed: ret_guard is null");

ET_CHECK_OR_RETURN_ERROR(
stream != nullptr,
InvalidArgument,
"aoti_torch_create_cuda_stream_guard failed: stream is null");

auto result =
CUDAStreamGuard::create(static_cast<cudaStream_t>(stream), device_index);
if (!result.ok()) {
return result.error();
}
*ret_guard = new CUDAStreamGuard(std::move(result.get()));
return Error::Ok;
}

AOTITorchError aoti_torch_delete_cuda_stream_guard(
CUDAStreamGuardHandle guard) {
ET_CHECK_OR_RETURN_ERROR(
guard != nullptr,
InvalidArgument,
"aoti_torch_delete_cuda_stream_guard failed: guard is null");

delete guard;
return Error::Ok;
}

AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream) {
ET_CHECK_OR_RETURN_ERROR(
ret_stream != nullptr,
InvalidArgument,
"aoti_torch_get_current_cuda_stream failed: ret_stream is null");

auto result = getCurrentCUDAStream(device_index);
if (!result.ok()) {
return result.error();
}
*ret_stream = static_cast<void*>(result.get());
return Error::Ok;
}

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
104 changes: 104 additions & 0 deletions backends/cuda/runtime/shims/cuda_guard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/guard.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {

using executorch::backends::aoti::AOTITorchError;

extern "C" {

// Handle types for CUDA guards
using CUDAGuardHandle = CUDAGuard*;
using CUDAStreamGuardHandle = CUDAStreamGuard*;

/**
* Creates a CUDA device guard that sets the current device and restores it
* upon destruction.
*
* @param device_index The device index to set as current
* @param ret_guard Output parameter for the created guard handle (must not be
* null)
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_create_cuda_guard(
int32_t device_index,
CUDAGuardHandle* ret_guard);

/**
* Deletes a CUDA device guard and frees its associated resources.
*
* @param guard Handle to the guard to be deleted
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);

/**
* Sets the CUDA device to a new index for an existing guard.
*
* @param guard Handle to the guard
* @param device_index The device index to set as current
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_cuda_guard_set_index(
CUDAGuardHandle guard,
int32_t device_index);

/**
* Creates a CUDA stream guard that sets the current device and stream,
* restoring both upon destruction.
*
* @param stream The CUDA stream to set as current
* @param device_index The device index for the stream
* @param ret_guard Output parameter for the created guard handle (must not be
* null)
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_create_cuda_stream_guard(
void* stream,
int32_t device_index,
CUDAStreamGuardHandle* ret_guard);

/**
* Deletes a CUDA stream guard and frees its associated resources.
*
* @param guard Handle to the stream guard to be deleted
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);

/**
* Gets the current CUDA stream for a specified device.
*
* @param device_index The device index (-1 to use current device)
* @param ret_stream Output parameter for the current stream (must not be null)
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream);

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_cpp_unittest("aoti_torch_copy_")
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
Loading
Loading