
Commit b5d9a41

Merge branch 'main' into sync-pt-commit
2 parents: 56fb60e + 45336ce

Note: large commits have some content hidden by default, so not every changed file is shown here.

47 files changed (+3879 / -1260 lines)

.github/workflows/lint.yml

Lines changed: 22 additions & 13 deletions
@@ -143,19 +143,28 @@ jobs:
           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
           timeout: 90
           script: |
-            FILES_NEEDS_FORMAT=$(/opt/google-java-format -n \
-              extension/android/executorch_android/src/main/java/org/pytorch/executorch/*.java \
-              extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/*.java \
-              extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/*.java \
-              extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/*.java \
-              extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/*.java \
-              extension/benchmark/android/benchmark/app/src/androidTest/java/org/pytorch/minibench/*.java)
+            FILES_NEEDS_FORMAT=$(find extension/android/executorch_android/src/main/java/org/pytorch/executorch \
+              extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm \
+              extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations \
+              extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch \
+              extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench \
+              extension/benchmark/android/benchmark/app/src/androidTest/java/org/pytorch/minibench \
+              -type f -name "*.java" 2>/dev/null | \
+              xargs -r /opt/google-java-format -n)
+
             if [ -n "$FILES_NEEDS_FORMAT" ]; then
-              echo "Warning: The following files need formatting. Please use google-java-format."
-              echo "Use a binary from https://github.com/google/google-java-format/releases/"
-              echo "For example:"
-              echo "wget https://github.com/google/google-java-format/releases/download/v1.23.0/google-java-format_linux-x86-64"
-              echo "chmod +x google-java-format_linux-x86-64"
-              echo "./google-java-format_linux-x86-64 -i $FILES_NEEDS_FORMAT"
+              echo "Warning: The following files need formatting:"
+              echo "$FILES_NEEDS_FORMAT"
+              echo ""
+              echo "Please use google-java-format from https://github.com/google/google-java-format/releases/"
+              echo ""
+              echo "To fix, run one of these commands:"
+              echo "  # Using xargs (recommended):"
+              echo "  find <paths> -type f -name '*.java' | xargs google-java-format -i"
+              echo ""
+              echo "  # Or format specific files:"
+              echo "$FILES_NEEDS_FORMAT" | while IFS= read -r file; do
+                echo "  google-java-format -i \"$file\""
+              done
               exit 1
             fi

backends/aoti/common_shims.cpp

Lines changed: 12 additions & 0 deletions
@@ -172,6 +172,18 @@ int32_t aoti_torch_dtype_bfloat16() {
   return 15; // PyTorch's bfloat16 dtype code
 }
 
+int32_t aoti_torch_dtype_int8() {
+  return 1; // PyTorch's int8 dtype code
+}
+
+int32_t aoti_torch_dtype_int16() {
+  return 2; // PyTorch's int16 dtype code
+}
+
+int32_t aoti_torch_dtype_int32() {
+  return 3; // PyTorch's int32 dtype code
+}
+
 int32_t aoti_torch_dtype_int64() {
   return 4; // PyTorch's int64 dtype code
 }

backends/aoti/common_shims.h

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@ int32_t aoti_torch_device_type_cpu();
 int32_t aoti_torch_layout_strided();
 int32_t aoti_torch_dtype_float32();
 int32_t aoti_torch_dtype_bfloat16();
+int32_t aoti_torch_dtype_int8();
+int32_t aoti_torch_dtype_int16();
+int32_t aoti_torch_dtype_int32();
 int32_t aoti_torch_dtype_int64();
 
 // Dtype utility function needed by Metal backend

backends/aoti/utils.h

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
   // Convert based on known PyTorch dtype codes (without CUDA-specific
   // dependency)
   switch (dtype) {
+    case 1: // PyTorch's int8 dtype code
+      return executorch::aten::ScalarType::Char;
+    case 2: // PyTorch's int16 dtype code
+      return executorch::aten::ScalarType::Short;
+    case 3: // PyTorch's int32 dtype code
+      return executorch::aten::ScalarType::Int;
     case 4: // PyTorch's int64 dtype code
       return executorch::aten::ScalarType::Long;
     case 6: // PyTorch's float32 dtype code
backends/arm/test/misc/test_debug_feats.py

Lines changed: 5 additions & 2 deletions
@@ -21,6 +21,7 @@
     TosaPipelineFP,
     TosaPipelineINT,
 )
+from executorch.backends.test.harness.stages import StageType
 
 
 input_t1 = Tuple[torch.Tensor]  # Input x
@@ -104,7 +105,7 @@ def test_INT_artifact(test_data: input_t1):
 
 @common.parametrize("test_data", Linear.inputs)
 def test_numerical_diff_print(test_data: input_t1):
-    pipeline = TosaPipelineFP[input_t1](
+    pipeline = TosaPipelineINT[input_t1](
         Linear(),
         test_data,
         [],
@@ -119,7 +120,9 @@ def test_numerical_diff_print(test_data: input_t1):
     # not present.
     try:
         # Tolerate 0 difference => we want to trigger a numerical diff
-        tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)
+        tester.run_method_and_compare_outputs(
+            stage=StageType.INITIAL_MODEL, atol=0, rtol=0, qtol=0
+        )
     except AssertionError:
         pass  # Implicit pass test
     else:

backends/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ find_package_torch()
 set(_aoti_cuda_sources
     runtime/cuda_backend.cpp runtime/shims/memory.cpp
     runtime/shims/tensor_attribute.cpp runtime/guard.cpp
-    runtime/shims/cuda_guard.cpp
+    runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
 )
 add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
 target_include_directories(

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
@@ -33,7 +33,9 @@
 }
 
 # exist fallback operators in et namespace;
-supported_fallback_kernels: Dict[str, Any] = {}
+supported_fallback_kernels: Dict[str, Any] = {
+    "at::_ops::_weight_int4pack_mm::call": None,
+}
 
 # required fallback kernels but not supported
 missing_fallback_kernels: Set[str] = set()

backends/cuda/runtime/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -1,4 +1,5 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")
 
 oncall("executorch")
 
@@ -7,12 +8,15 @@ runtime.cxx_library(
     srcs = [
         "guard.cpp",
         "shims/cuda_guard.cpp",
+        "shims/int4mm.cu",
         "shims/memory.cpp",
         "shims/tensor_attribute.cpp",
     ],
     headers = [
         "guard.h",
         "shims/cuda_guard.h",
+        "shims/int4mm.cuh",
+        "shims/int4mm.h",
         "shims/memory.h",
         "shims/tensor_attribute.h",
         "utils.h",
@@ -30,6 +34,10 @@ runtime.cxx_library(
         "//executorch/runtime/core/exec_aten:lib",
         "//executorch/runtime/platform:platform",
     ],
+    nvcc_flags = get_nvcc_arch_args() + [
+        "-_NVCC_HOST_COMPILER_FLAG_",
+        "gcc",
+    ],
     external_deps = [
         ("cuda", None, "cuda-lazy"),
     ],
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <executorch/backends/aoti/utils.h>
+#include <executorch/backends/cuda/runtime/shims/int4mm.h>
+#include <executorch/backends/cuda/runtime/shims/int4mm.cuh>
+#include <executorch/runtime/platform/log.h>
+
+namespace executorch::backends::cuda {
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+AOTITorchError aoti_torch_cuda__weight_int4pack_mm(
+    Tensor* self,
+    Tensor* mat2,
+    int64_t qGroupSize,
+    Tensor* qScaleAndZeros,
+    Tensor** ret0) {
+  // Validate input parameters first
+  // Only check for null pointers here, as the actual validation of tensor
+  // properties is done in _weight_int4pack_mm_cuda
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      mat2 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      qScaleAndZeros != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret0 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null");
+
+  *ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros);
+  ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
+  return Error::Ok;
+}
+
+#ifdef __cplusplus
+}
+#endif
+} // namespace executorch::backends::cuda
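
For orientation, a hypothetical caller of this shim might look like the sketch below. This is an editor-added illustration, not part of the commit: the wrapper name and the source of the tensor pointers are assumptions, while Tensor and AOTITorchError are expected to come from int4mm.h as in the file above.

// Editor-added sketch: a hypothetical wrapper that forwards to the new shim.
// The shim itself performs the null checks and kernel-launch check shown
// above, so the caller only propagates the returned error.
#include <executorch/backends/cuda/runtime/shims/int4mm.h>

namespace executorch::backends::cuda {

AOTITorchError call_int4pack_mm_example(
    Tensor* activations,         // "self": input matrix
    Tensor* packed_weights,      // "mat2": int4-packed weight tensor
    int64_t q_group_size,        // quantization group size
    Tensor* q_scales_and_zeros,  // per-group scales and zero points
    Tensor** out) {              // receives the result tensor
  return aoti_torch_cuda__weight_int4pack_mm(
      activations, packed_weights, q_group_size, q_scales_and_zeros, out);
}

} // namespace executorch::backends::cuda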
