Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
# Ensure symbols are exported properly
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries
# Link against ExecuTorch libraries and standard libraries
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(aoti_common)

Expand Down
9 changes: 8 additions & 1 deletion backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() {
}

// Dtype constants - these return the PyTorch dtype codes
// Currently only float32 is supported, but using robust enum-based approach
int32_t aoti_torch_dtype_float32() {
return 6; // PyTorch's float32 dtype code
}

int32_t aoti_torch_dtype_bfloat16() {
return 15; // PyTorch's bfloat16 dtype code
}

int32_t aoti_torch_dtype_int64() {
return 4; // PyTorch's int64 dtype code
}

// Cleanup functions
void cleanup_tensor_metadata() {
internal::tensor_to_sizes.clear();
Expand Down
2 changes: 2 additions & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
int32_t aoti_torch_device_type_cpu();
int32_t aoti_torch_layout_strided();
int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int64();

// Autograd mode functions
int32_t aoti_torch_grad_mode_is_enabled();
Expand Down
5 changes: 1 addition & 4 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)

# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
target_link_libraries(
aoti_cuda
PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
# Link PyTorch libraries for AOTI CUDA functions
${TORCH_LIBRARIES}
aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
)
# If you need other CUDA libraries, link them similarly:
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
Expand Down
Loading