4 changes: 4 additions & 0 deletions backends/aoti/common_shims.cpp
@@ -135,6 +135,10 @@ int32_t aoti_torch_dtype_bfloat16() {
return 15; // PyTorch's bfloat16 dtype code
}

int32_t aoti_torch_dtype_int32() {
return 3; // PyTorch's int32 dtype code
}

int32_t aoti_torch_dtype_int64() {
return 4; // PyTorch's int64 dtype code
}
1 change: 1 addition & 0 deletions backends/aoti/common_shims.h
@@ -59,6 +59,7 @@ int32_t aoti_torch_device_type_cpu();
int32_t aoti_torch_layout_strided();
int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int32();
int32_t aoti_torch_dtype_int64();

// Autograd mode functions
2 changes: 2 additions & 0 deletions backends/aoti/utils.h
@@ -34,6 +34,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
// Convert based on known PyTorch dtype codes (without CUDA-specific
// dependency)
switch (dtype) {
case 3: // PyTorch's int32 dtype code
return executorch::aten::ScalarType::Int;
case 4: // PyTorch's int64 dtype code
return executorch::aten::ScalarType::Long;
case 6: // PyTorch's float32 dtype code
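For reference, the three hunks above are meant to agree on PyTorch's dtype code for int32 (3). A minimal sanity sketch, not part of the PR, assuming both declarations are in scope (the namespace qualification for dtype_to_scalar_type is an assumption here):

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/aoti/utils.h>

// Hypothetical check: the code returned by the new shim should round-trip
// through dtype_to_scalar_type to the matching ExecuTorch scalar type.
inline bool int32_dtype_round_trips() {
  return executorch::backends::aoti::dtype_to_scalar_type(
             aoti_torch_dtype_int32()) ==
      executorch::aten::ScalarType::Int;
}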
14 changes: 13 additions & 1 deletion backends/cuda/CMakeLists.txt
@@ -38,8 +38,20 @@ find_package_torch()
set(_aoti_cuda_sources
runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
-  runtime/shims/cuda_guard.cpp
+  runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
)
# Set default CUDA architectures if not specified (int4mm requires sm_80+)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES
"80;86;89;90"
CACHE STRING "CUDA architectures"
)
message(
STATUS
"CMAKE_CUDA_ARCHITECTURES not set, using default: 80;86;89;90 (Ampere+)"
)
message(STATUS " Override with: cmake -DCMAKE_CUDA_ARCHITECTURES=<arch> ...")
endif()
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
aoti_cuda
4 changes: 3 additions & 1 deletion backends/cuda/cuda_backend.py
@@ -28,7 +28,9 @@
from torch.nn.attention import SDPBackend

# existing fallback operators in the et namespace
-supported_fallback_kernels: Dict[str, Any] = {}
+supported_fallback_kernels: Dict[str, Any] = {
+    "at::_ops::_weight_int4pack_mm::call": None,
+}

# fallback kernels that are required but not yet supported
missing_fallback_kernels: Set[str] = set()
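For context, registering "at::_ops::_weight_int4pack_mm::call" lets AOTI-compiled graphs fall back to this ATen op at runtime; on this backend the call lands on the new C shim added below in int4mm.cu. A hedged sketch of such a call site (handle names and the qGroupSize of 128 are illustrative only):

// Roughly what the generated code dispatches to: self is the activation
// matrix (bfloat16 on CUDA), mat2 the packed int4 weights, and
// qScaleAndZeros the per-group quantization parameters.
Tensor* out = nullptr;
AOTITorchError err = aoti_torch_cuda__weight_int4pack_mm(
    self, mat2, /*qGroupSize=*/128, qScaleAndZeros, &out);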
3 changes: 3 additions & 0 deletions backends/cuda/runtime/TARGETS
@@ -7,12 +7,15 @@ runtime.cxx_library(
srcs = [
"guard.cpp",
"shims/cuda_guard.cpp",
"shims/int4mm.cu",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/cuda_guard.h",
"shims/int4mm.cuh",
"shims/int4mm.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
57 changes: 57 additions & 0 deletions backends/cuda/runtime/shims/int4mm.cu
@@ -0,0 +1,57 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cuda.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/int4mm.h>
#include <executorch/backends/cuda/runtime/shims/int4mm.cuh>
#include <executorch/runtime/platform/log.h>

namespace executorch::backends::cuda {
#ifdef __cplusplus
extern "C" {
#endif

AOTITorchError aoti_torch_cuda__weight_int4pack_mm(
Tensor* self,
Contributor: should we check whether self is bfloat16?

Author: We already have quite a few tensor checks in the actual
_weight_int4pack_mm_cuda function, so we don't have to repeat them here?

Tensor* mat2,
Contributor: check whether mat2 is int32.

(A hedged sketch of both suggested dtype checks appears after this file.)

int64_t qGroupSize,
Tensor* qScaleAndZeros,
Tensor** ret0) {
// Validate input parameters first
ET_CHECK_OR_RETURN_ERROR(
self != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null");

ET_CHECK_OR_RETURN_ERROR(
mat2 != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null");

ET_CHECK_OR_RETURN_ERROR(
qScaleAndZeros != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null");

ET_CHECK_OR_RETURN_ERROR(
ret0 != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null");

Contributor (suggested change): validate qGroupSize before dispatching to the kernel:
ET_CHECK_OR_RETURN_ERROR(
        qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || qGroupSize == 256,
        InvalidArgument,
        "aoti_torch_cuda__weight_int4pack_mm: qGroupSize must be 32/64/128/256, got %lld",
        static_cast<long long>(qGroupSize));

*ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros);
ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
return Error::Ok;
}

#ifdef __cplusplus
}
#endif
} // namespace executorch::backends::cuda
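Picking up the review thread above: a minimal sketch of what explicit dtype checks at the shim boundary could look like, should the validation inside _weight_int4pack_mm_cuda ever be deemed insufficient. The scalar_type() accessor and ScalarType names follow ExecuTorch conventions; this is an illustration, not part of the PR:

// Hypothetical pre-checks: fail fast with a descriptive error instead of
// relying solely on the checks inside _weight_int4pack_mm_cuda.
ET_CHECK_OR_RETURN_ERROR(
    self->scalar_type() == executorch::aten::ScalarType::BFloat16,
    InvalidArgument,
    "aoti_torch_cuda__weight_int4pack_mm: self must be bfloat16");
ET_CHECK_OR_RETURN_ERROR(
    mat2->scalar_type() == executorch::aten::ScalarType::Int,
    InvalidArgument,
    "aoti_torch_cuda__weight_int4pack_mm: mat2 must be int32");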