File tree Expand file tree Collapse file tree 2 files changed +76
-0
lines changed
backends/cuda/runtime/shims Expand file tree Collapse file tree 2 files changed +76
-0
lines changed Original file line number Diff line number Diff line change
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include < executorch/backends/cuda/runtime/shims/tensor_attribute.h>
10
+
11
+ namespace executorch {
12
+ namespace backends {
13
+ namespace cuda {
14
+
15
+ extern " C" {
16
+
17
+ // Device type functions for tensor attributes
18
+ AOTITorchError aoti_torch_get_device_type (
19
+ Tensor* tensor,
20
+ int32_t * ret_device_type) {
21
+ // All tensors in aoti-cuda delegate are on CUDA
22
+ *ret_device_type = aoti_torch_device_type_cuda ();
23
+ return Error::Ok;
24
+ }
25
+
26
+ // Device type constants
27
+ int32_t aoti_torch_device_type_cuda () {
28
+ // Let's say cuda is 1 for ET as well
29
+ return 1 ;
30
+ }
31
+
32
+ } // extern "C"
33
+
34
+ } // namespace cuda
35
+ } // namespace backends
36
+ } // namespace executorch
Original file line number Diff line number Diff line change
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include < executorch/extension/tensor/tensor.h>
12
+ #include < executorch/runtime/core/error.h>
13
+ #include < cstdint>
14
+
15
+ namespace executorch {
16
+ namespace backends {
17
+ namespace cuda {
18
+
19
+ // Common using declarations for ExecutorTorch types
20
+ using executorch::runtime::Error;
21
+ using executorch::runtime::etensor::Tensor;
22
+
23
+ extern " C" {
24
+
25
+ // Common AOTI type aliases
26
+ using AOTITorchError = Error;
27
+
28
+ // Device type functions for tensor attributes
29
+ AOTITorchError aoti_torch_get_device_type (
30
+ Tensor* tensor,
31
+ int32_t * ret_device_type);
32
+
33
+ // Device type constants
34
+ int32_t aoti_torch_device_type_cuda ();
35
+
36
+ } // extern "C"
37
+
38
+ } // namespace cuda
39
+ } // namespace backends
40
+ } // namespace executorch
You can’t perform that action at this time.
0 commit comments