
Commit b571906

Support aoti_torch_cuda__weight_int4pack_mm
Summary: When quantizing a model with 4w_hqq (huggingface/optimum-executorch#164), AOTI-generated code calls `aoti_torch_cuda__weight_int4pack_mm` as a fallback op. This PR borrows the CUDA implementation of `_weight_int4pack_mm_cuda` from libtorch, replacing `at::Tensor` and the relevant utility functions with their ExecuTorch equivalents.

Using the Voxtral runner as an example: with the bfloat16 format, the generated ptd file size and latency are:

```
aoti_cuda_blob.ptd: 9.0 GB

Program load latency (ms): 0.054
Method load latency (ms):
  audio_encoder: 1492.989
  token_embedding: 803.561
  text_decoder: 6556.770
Run latency (ms):
  audio_encoder: 76.848
  token_embedding: 6.479
  text_decoder: 149.128
```

With `--qlinear 4w_hqq --qlinear_encoder 4w_hqq`, the ptd file size is cut by more than half, with slowdowns in the encoder and decoder:

```
aoti_cuda_blob.ptd: 3.7 GB

Program load latency (ms): 0.051
Method load latency (ms):
  audio_encoder: 716.667
  token_embedding: 633.476
  text_decoder: 1840.760
Run latency (ms):
  audio_encoder: 329.274
  token_embedding: 4.285
  text_decoder: 335.590
```

ghstack-source-id: 29b5b16
Pull Request resolved: #15030
1 parent afd98fe commit b571906
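For context, a minimal caller-side sketch of the new shim, using the signature added in this commit; the tensor names and the group size below are illustrative placeholders, not taken from actual AOTI-generated code:

```cpp
#include <executorch/backends/cuda/runtime/shims/int4mm.h>

// Hypothetical call site, shaped like the code AOTI emits for the fallback.
// activations, packed_weights, and scales_and_zeros are placeholder names.
Tensor* out = nullptr;
AOTITorchError err = aoti_torch_cuda__weight_int4pack_mm(
    activations,         // bfloat16 input of shape [m, k]
    packed_weights,      // int4-packed weight tensor
    /*qGroupSize=*/128,  // quantization group size (assumed value)
    scales_and_zeros,    // per-group scales and zero points
    &out);               // receives the result tensor on success
if (err != Error::Ok) {
  // Propagate the failure; the shim reports InvalidArgument for null inputs
  // and surfaces CUDA kernel-launch errors.
}
```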

File tree

12 files changed: +1883 −5 lines

backends/aoti/common_shims.cpp

Lines changed: 12 additions & 0 deletions
```
@@ -172,6 +172,18 @@ int32_t aoti_torch_dtype_bfloat16() {
   return 15; // PyTorch's bfloat16 dtype code
 }
 
+int32_t aoti_torch_dtype_int8() {
+  return 1; // PyTorch's int8 dtype code
+}
+
+int32_t aoti_torch_dtype_int16() {
+  return 2; // PyTorch's int16 dtype code
+}
+
+int32_t aoti_torch_dtype_int32() {
+  return 3; // PyTorch's int32 dtype code
+}
+
 int32_t aoti_torch_dtype_int64() {
   return 4; // PyTorch's int64 dtype code
 }
```

backends/aoti/common_shims.h

Lines changed: 3 additions & 0 deletions
```
@@ -59,6 +59,9 @@ int32_t aoti_torch_device_type_cpu();
 int32_t aoti_torch_layout_strided();
 int32_t aoti_torch_dtype_float32();
 int32_t aoti_torch_dtype_bfloat16();
+int32_t aoti_torch_dtype_int8();
+int32_t aoti_torch_dtype_int16();
+int32_t aoti_torch_dtype_int32();
 int32_t aoti_torch_dtype_int64();
 
 // Autograd mode functions
```

backends/aoti/utils.h

Lines changed: 6 additions & 0 deletions
```
@@ -34,6 +34,12 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
   // Convert based on known PyTorch dtype codes (without CUDA-specific
   // dependency)
   switch (dtype) {
+    case 1: // PyTorch's int8 dtype code
+      return executorch::aten::ScalarType::Char;
+    case 2: // PyTorch's int16 dtype code
+      return executorch::aten::ScalarType::Short;
+    case 3: // PyTorch's int32 dtype code
+      return executorch::aten::ScalarType::Int;
     case 4: // PyTorch's int64 dtype code
       return executorch::aten::ScalarType::Long;
     case 6: // PyTorch's float32 dtype code
```
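The new codes follow PyTorch's ScalarType enum ordering (1 = Char/int8, 2 = Short/int16, 3 = Int/int32), so the shim constants and this converter stay in agreement. A quick illustrative check, not part of the diff (the includes and unqualified name lookup are assumptions):

```cpp
#include <cassert>

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/aoti/utils.h>

// Sanity-check that the dtype-code shims round-trip through
// dtype_to_scalar_type(); assumes both names are visible in this scope.
void check_int_dtype_codes() {
  using executorch::aten::ScalarType;
  assert(dtype_to_scalar_type(aoti_torch_dtype_int8()) == ScalarType::Char);
  assert(dtype_to_scalar_type(aoti_torch_dtype_int16()) == ScalarType::Short);
  assert(dtype_to_scalar_type(aoti_torch_dtype_int32()) == ScalarType::Int);
}
```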

backends/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```
@@ -38,7 +38,7 @@ find_package_torch()
 set(_aoti_cuda_sources
     runtime/cuda_backend.cpp runtime/shims/memory.cpp
     runtime/shims/tensor_attribute.cpp runtime/guard.cpp
-    runtime/shims/cuda_guard.cpp
+    runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
 )
 add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
 target_include_directories(
```

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
```
@@ -28,7 +28,9 @@
 from torch.nn.attention import SDPBackend
 
 # exist fallback operators in et namespace;
-supported_fallback_kernels: Dict[str, Any] = {}
+supported_fallback_kernels: Dict[str, Any] = {
+    "at::_ops::_weight_int4pack_mm::call": None,
+}
 
 # required fallback kernels but not supported
 missing_fallback_kernels: Set[str] = set()
```

backends/cuda/runtime/TARGETS

Lines changed: 8 additions & 0 deletions
```
@@ -1,4 +1,5 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")
 
 oncall("executorch")
 
@@ -7,12 +8,15 @@ runtime.cxx_library(
     srcs = [
         "guard.cpp",
         "shims/cuda_guard.cpp",
+        "shims/int4mm.cu",
         "shims/memory.cpp",
         "shims/tensor_attribute.cpp",
     ],
     headers = [
         "guard.h",
         "shims/cuda_guard.h",
+        "shims/int4mm.cuh",
+        "shims/int4mm.h",
         "shims/memory.h",
         "shims/tensor_attribute.h",
         "utils.h",
@@ -30,6 +34,10 @@ runtime.cxx_library(
         "//executorch/runtime/core/exec_aten:lib",
         "//executorch/runtime/platform:platform",
     ],
+    nvcc_flags = get_nvcc_arch_args() + [
+        "-_NVCC_HOST_COMPILER_FLAG_",
+        "gcc",
+    ],
     external_deps = [
         ("cuda", None, "cuda-lazy"),
     ],
```
backends/cuda/runtime/shims/int4mm.cu

Lines changed: 59 additions & 0 deletions

```
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <executorch/backends/aoti/utils.h>
+#include <executorch/backends/cuda/runtime/shims/int4mm.h>
+#include <executorch/backends/cuda/runtime/shims/int4mm.cuh>
+#include <executorch/runtime/platform/log.h>
+
+namespace executorch::backends::cuda {
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+AOTITorchError aoti_torch_cuda__weight_int4pack_mm(
+    Tensor* self,
+    Tensor* mat2,
+    int64_t qGroupSize,
+    Tensor* qScaleAndZeros,
+    Tensor** ret0) {
+  // Validate input parameters first
+  // Only check for null pointers here, as the actual validation of tensor
+  // properties is done in _weight_int4pack_mm_cuda
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      mat2 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      qScaleAndZeros != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret0 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null");
+
+  *ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros);
+  ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
+  return Error::Ok;
+}
+
+#ifdef __cplusplus
+}
+#endif
+} // namespace executorch::backends::cuda
```
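Semantically, the wrapped kernel performs a bfloat16 matmul against weights that are dequantized on the fly: each contiguous group of `qGroupSize` int4 values shares one (scale, zero) pair from `qScaleAndZeros`. A scalar reference of the per-element dequantization under the tinygemm-style affine convention (an assumption for illustration; the actual CUDA kernel operates on packed layouts):

```cpp
// Reference-only sketch: dequantize one unpacked 4-bit value q in [0, 15].
// Assumes the tinygemm-style convention value = (q - 8) * scale + zero,
// where scale and zero are this group's pair from qScaleAndZeros.
inline float dequantize_int4(int q, float scale, float zero) {
  return static_cast<float>(q - 8) * scale + zero;
}
```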
