Skip to content

Commit 769684a

Browse files
authored
tensor attribute shim layers
Differential Revision: D83012802 Pull Request resolved: #14546
1 parent fabbda6 commit 769684a

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
10+
11+
namespace executorch {
12+
namespace backends {
13+
namespace cuda {
14+
15+
extern "C" {
16+
17+
// Device type functions for tensor attributes
18+
AOTITorchError aoti_torch_get_device_type(
19+
Tensor* tensor,
20+
int32_t* ret_device_type) {
21+
// All tensors in aoti-cuda delegate are on CUDA
22+
*ret_device_type = aoti_torch_device_type_cuda();
23+
return Error::Ok;
24+
}
25+
26+
// Device type constants
27+
int32_t aoti_torch_device_type_cuda() {
28+
// Let's say cuda is 1 for ET as well
29+
return 1;
30+
}
31+
32+
} // extern "C"
33+
34+
} // namespace cuda
35+
} // namespace backends
36+
} // namespace executorch
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/runtime/core/error.h>
13+
#include <cstdint>
14+
15+
namespace executorch {
16+
namespace backends {
17+
namespace cuda {
18+
19+
// Common using declarations for ExecutorTorch types
20+
using executorch::runtime::Error;
21+
using executorch::runtime::etensor::Tensor;
22+
23+
extern "C" {
24+
25+
// Common AOTI type aliases
26+
using AOTITorchError = Error;
27+
28+
// Device type functions for tensor attributes
29+
AOTITorchError aoti_torch_get_device_type(
30+
Tensor* tensor,
31+
int32_t* ret_device_type);
32+
33+
// Device type constants
34+
int32_t aoti_torch_device_type_cuda();
35+
36+
} // extern "C"
37+
38+
} // namespace cuda
39+
} // namespace backends
40+
} // namespace executorch

0 commit comments

Comments
 (0)