diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt index ce364f2c4b0..845144af50f 100644 --- a/backends/aoti/CMakeLists.txt +++ b/backends/aoti/CMakeLists.txt @@ -40,7 +40,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC) # Ensure symbols are exported properly target_link_options(aoti_common PUBLIC -Wl,--export-dynamic) -# Link against PyTorch libraries and standard libraries +# Link against ExecuTorch libraries and standard libraries target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS}) executorch_target_link_options_shared_lib(aoti_common) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 2f9b36e3c4f..abc83779443 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() { } // Dtype constants - these return the PyTorch dtype codes -// Currently only float32 is supported, but using robust enum-based approach int32_t aoti_torch_dtype_float32() { return 6; // PyTorch's float32 dtype code } +int32_t aoti_torch_dtype_bfloat16() { + return 15; // PyTorch's bfloat16 dtype code +} + +int32_t aoti_torch_dtype_int64() { + return 4; // PyTorch's int64 dtype code +} + // Cleanup functions void cleanup_tensor_metadata() { internal::tensor_to_sizes.clear(); diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index ffcbaa11a08..5f54cd1c878 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); int32_t aoti_torch_device_type_cpu(); int32_t aoti_torch_layout_strided(); int32_t aoti_torch_dtype_float32(); +int32_t aoti_torch_dtype_bfloat16(); +int32_t aoti_torch_dtype_int64(); // Autograd mode functions int32_t aoti_torch_grad_mode_is_enabled(); diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index dc5b1b786f8..575f676e4cc 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -55,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic) # Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries target_link_libraries( - aoti_cuda - PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} - # Link PyTorch libraries for AOTI CUDA functions - ${TORCH_LIBRARIES} + aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} ) # If you need other CUDA libraries, link them similarly: # target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)