File tree Expand file tree Collapse file tree 2 files changed +76
-0
lines changed
backends/cuda/runtime/shims Expand file tree Collapse file tree 2 files changed +76
-0
lines changed Original file line number Diff line number Diff line change 1+ /*
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3+ * All rights reserved.
4+ *
5+ * This source code is licensed under the BSD-style license found in the
6+ * LICENSE file in the root directory of this source tree.
7+ */
8+
9+ #include < executorch/backends/cuda/runtime/shims/tensor_attribute.h>
10+
11+ namespace executorch {
12+ namespace backends {
13+ namespace cuda {
14+
15+ extern " C" {
16+
17+ // Device type functions for tensor attributes
18+ AOTITorchError aoti_torch_get_device_type (
19+ Tensor* tensor,
20+ int32_t * ret_device_type) {
21+ // All tensors in aoti-cuda delegate are on CUDA
22+ *ret_device_type = aoti_torch_device_type_cuda ();
23+ return Error::Ok;
24+ }
25+
26+ // Device type constants
27+ int32_t aoti_torch_device_type_cuda () {
28+ // Let's say cuda is 1 for ET as well
29+ return 1 ;
30+ }
31+
32+ } // extern "C"
33+
34+ } // namespace cuda
35+ } // namespace backends
36+ } // namespace executorch
Original file line number Diff line number Diff line change 1+ /*
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3+ * All rights reserved.
4+ *
5+ * This source code is licensed under the BSD-style license found in the
6+ * LICENSE file in the root directory of this source tree.
7+ */
8+
9+ #pragma once
10+
11+ #include < executorch/extension/tensor/tensor.h>
12+ #include < executorch/runtime/core/error.h>
13+ #include < cstdint>
14+
15+ namespace executorch {
16+ namespace backends {
17+ namespace cuda {
18+
19+ // Common using declarations for ExecutorTorch types
20+ using executorch::runtime::Error;
21+ using executorch::runtime::etensor::Tensor;
22+
23+ extern " C" {
24+
25+ // Common AOTI type aliases
26+ using AOTITorchError = Error;
27+
28+ // Device type functions for tensor attributes
29+ AOTITorchError aoti_torch_get_device_type (
30+ Tensor* tensor,
31+ int32_t * ret_device_type);
32+
33+ // Device type constants
34+ int32_t aoti_torch_device_type_cuda ();
35+
36+ } // extern "C"
37+
38+ } // namespace cuda
39+ } // namespace backends
40+ } // namespace executorch
You can’t perform that action at this time.
0 commit comments