Skip to content

Commit 9c67c64

Browse files
authored
Merge branch 'main' into cuda-err-msg
2 parents cf8c7b0 + 2c706f1 commit 9c67c64

File tree

19 files changed

+348
-85
lines changed

19 files changed

+348
-85
lines changed

.githooks/pre-commit

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
88
echo "📝 Updating PyTorch commit pin..."
99

1010
# Run the update script
11-
if python .github/scripts/update_pytorch_pin.py; then
11+
hook_output=$(python .github/scripts/update_pytorch_pin.py 2>&1)
12+
hook_status=$?
13+
echo "$hook_output"
14+
15+
if [ $hook_status -eq 0 ]; then
1216
# Check if pytorch.txt was modified
1317
if ! git diff --quiet .ci/docker/ci_commit_pins/pytorch.txt; then
1418
echo "✅ PyTorch commit pin updated successfully"
@@ -19,9 +23,14 @@ if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
1923
echo "ℹ️ PyTorch commit pin unchanged"
2024
fi
2125
else
22-
echo "❌ Failed to update PyTorch commit pin"
23-
echo "Please run: python .github/scripts/update_pytorch_pin.py"
24-
exit 1
26+
if echo "$hook_output" | grep -qi "rate limit exceeded"; then
27+
echo "⚠️ PyTorch commit pin not updated due to GitHub API rate limiting."
28+
echo " Please manually update .ci/docker/ci_commit_pins/pytorch.txt if needed."
29+
else
30+
echo "❌ Failed to update PyTorch commit pin"
31+
echo "Please run: python .github/scripts/update_pytorch_pin.py"
32+
exit 1
33+
fi
2534
fi
2635
fi
2736

.github/scripts/update_pytorch_pin.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import re
55
import sys
66
import urllib.request
7-
from datetime import datetime
87

98

109
def parse_nightly_version(nightly_version):
@@ -53,7 +52,7 @@ def get_commit_hash_for_nightly(date_str):
5352
Commit hash string
5453
"""
5554
api_url = "https://api.github.com/repos/pytorch/pytorch/commits"
56-
params = f"?sha=nightly&per_page=100"
55+
params = f"?sha=nightly&per_page=50"
5756
url = api_url + params
5857

5958
req = urllib.request.Request(url)
@@ -74,14 +73,21 @@ def get_commit_hash_for_nightly(date_str):
7473
commit_msg = commit.get("commit", {}).get("message", "")
7574
# Check if the first line of commit message matches
7675
first_line = commit_msg.split("\n")[0].strip()
77-
if first_line == target_title or first_line.startswith(f"{date_str} nightly"):
78-
return commit["sha"]
76+
if first_line.startswith(f"{date_str} nightly"):
77+
return extract_hash_from_title(first_line)
7978

8079
raise ValueError(
8180
f"Could not find commit with title matching '{target_title}' in nightly branch"
8281
)
8382

8483

84+
def extract_hash_from_title(title):
85+
match = re.search(r"\(([0-9a-fA-F]{7,40})\)", title)
86+
if not match:
87+
raise ValueError(f"Could not extract commit hash from title '{title}'")
88+
return match.group(1)
89+
90+
8591
def update_pytorch_pin(commit_hash):
8692
"""
8793
Update .ci/docker/ci_commit_pins/pytorch.txt with the new commit hash.

backends/aoti/utils.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,64 @@ inline bool is_tensor_contiguous(
100100

101101
} // extern "C"
102102

103+
// Utility function to convert sizes pointer to vector
104+
inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
105+
int64_t ndim,
106+
const int64_t* sizes_ptr) {
107+
std::vector<executorch::aten::SizesType> sizes(ndim);
108+
for (int i = 0; i < ndim; i++) {
109+
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
110+
}
111+
return sizes;
112+
}
113+
114+
// Utility function to convert strides pointer to vector or calculate from sizes
115+
inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
116+
int64_t ndim,
117+
const int64_t* sizes_ptr,
118+
const int64_t* strides_ptr) {
119+
std::vector<executorch::aten::StridesType> strides(ndim);
120+
121+
if (strides_ptr != nullptr) {
122+
// Use provided strides.
123+
for (int64_t i = 0; i < ndim; i++) {
124+
strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
125+
}
126+
} else {
127+
// Calculate strides from sizes.
128+
if (ndim > 0) {
129+
strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
130+
1); // Last dimension has stride 1
131+
for (int64_t i = ndim - 2; i >= 0; i--) {
132+
if (sizes_ptr[i + 1] == 0) {
133+
strides[i] = strides[i + 1]; // Copy stride when size is 0
134+
} else {
135+
strides[i] = static_cast<executorch::aten::StridesType>(
136+
static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
137+
}
138+
}
139+
}
140+
}
141+
return strides;
142+
}
143+
144+
// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
145+
// Contiguous format means strides decrease from left to right:
146+
// For NCHW: strides = [C*H*W, H*W, W, 1]
147+
inline bool is_contiguous_tensor(
148+
std::vector<executorch::aten::SizesType>& sizes,
149+
std::vector<executorch::aten::StridesType>& strides) {
150+
int64_t ndim = static_cast<int64_t>(strides.size());
151+
int64_t expected_stride = 1;
152+
for (int64_t i = ndim - 1; i >= 0; i--) {
153+
if (strides[i] != expected_stride) {
154+
return false;
155+
}
156+
expected_stride *= sizes[i];
157+
}
158+
return true;
159+
}
160+
103161
} // namespace aoti
104162
} // namespace backends
105163
} // namespace executorch
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
10+
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
11+
#include <iostream>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace metal {
16+
17+
extern "C" {
18+
19+
// Metal-specific device type constant
20+
__attribute__((__visibility__("default"))) int32_t
21+
aoti_torch_device_type_mps() {
22+
return 13; // Consistent with c10/core/DeviceType.h
23+
}
24+
25+
// Override aoti_torch_get_device_type to return MPS device type
26+
AOTITorchError aoti_torch_get_device_type(
27+
AOTITensorHandle tensor,
28+
int32_t* ret_device_type) {
29+
*ret_device_type = aoti_torch_device_type_mps();
30+
return Error::Ok;
31+
}
32+
33+
} // extern "C"
34+
35+
} // namespace metal
36+
} // namespace backends
37+
} // namespace executorch
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/aoti/common_shims.h>
12+
#include <executorch/backends/apple/metal/runtime/shims/types.h>
13+
14+
namespace executorch {
15+
namespace backends {
16+
namespace metal {
17+
18+
extern "C" {
19+
20+
// Metal-specific device type function
21+
int32_t aoti_torch_device_type_mps();
22+
23+
// Override aoti_torch_get_device_type to return MPS device type
24+
AOTITorchError aoti_torch_get_device_type(
25+
AOTITensorHandle tensor,
26+
int32_t* ret_device_type);
27+
28+
} // extern "C"
29+
30+
} // namespace metal
31+
} // namespace backends
32+
} // namespace executorch
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/runtime/core/error.h>
13+
#include <cstdint>
14+
15+
namespace executorch {
16+
namespace backends {
17+
namespace metal {
18+
19+
// Common using declarations for ExecutorTorch types
20+
using executorch::runtime::Error;
21+
using executorch::runtime::etensor::Tensor;
22+
23+
extern "C" {
24+
25+
// Common AOTI type aliases
26+
// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility
27+
using AOTITensorHandle = Tensor*;
28+
using AOTIRuntimeError = Error;
29+
using AOTITorchError = Error;
30+
31+
} // extern "C"
32+
33+
} // namespace metal
34+
} // namespace backends
35+
} // namespace executorch
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
10+
#include <executorch/runtime/platform/log.h>
11+
#include <cstdint>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace metal {
16+
17+
extern "C" {
18+
19+
// Helper function to check if a dtype is supported in Metal backend
20+
bool is_dtype_supported_in_et_metal(int32_t dtype) {
21+
switch (dtype) {
22+
case static_cast<int32_t>(SupportedDTypes::INT64):
23+
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
24+
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
25+
return true;
26+
default:
27+
return false;
28+
}
29+
}
30+
31+
// Metal-specific dtype validation utility function
32+
AOTITorchError validate_dtype(int32_t dtype) {
33+
if (is_dtype_supported_in_et_metal(dtype)) {
34+
return Error::Ok;
35+
}
36+
37+
ET_LOG(
38+
Error,
39+
"Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
40+
dtype,
41+
static_cast<int32_t>(SupportedDTypes::INT64),
42+
static_cast<int32_t>(SupportedDTypes::FLOAT32),
43+
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
44+
return Error::InvalidArgument;
45+
}
46+
47+
} // extern "C"
48+
49+
} // namespace metal
50+
} // namespace backends
51+
} // namespace executorch
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/aoti/utils.h>
12+
#include <executorch/backends/apple/metal/runtime/shims/types.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <cstdint>
15+
16+
namespace executorch {
17+
namespace backends {
18+
namespace metal {
19+
20+
// Enum for supported data types in et-metal backend
21+
enum class SupportedDTypes : int32_t {
22+
// UINT8 = 0, // PyTorch's uint8 dtype code
23+
// INT8 = 1, // PyTorch's int8 dtype code
24+
// INT16 = 2, // PyTorch's int16 dtype code
25+
// INT32 = 3, // PyTorch's int32 dtype code
26+
INT64 = 4, // PyTorch's int64 dtype code
27+
// FLOAT16 = 5, // PyTorch's float16 dtype code
28+
FLOAT32 = 6, // PyTorch's float32 dtype code
29+
// FLOAT64 = 7, // PyTorch's float64 dtype code
30+
// BOOL = 11, // PyTorch's bool dtype code
31+
BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
32+
};
33+
34+
extern "C" {
35+
36+
// Helper function to check if a dtype is supported in Metal backend
37+
bool is_dtype_supported_in_et_metal(int32_t dtype);
38+
39+
// Metal-specific dtype validation utility function
40+
AOTITorchError validate_dtype(int32_t dtype);
41+
42+
} // extern "C"
43+
44+
} // namespace metal
45+
} // namespace backends
46+
} // namespace executorch

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
ExecutorchProgramManager,
3939
)
4040
from executorch.exir.passes import ToOutVarPass
41-
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
41+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
4242
from executorch.exir.program._program import to_edge
4343

4444
from torch.export.exported_program import ExportedProgram
@@ -460,7 +460,7 @@ def _lower_ep_to_cadence_gen_etrecord(
460460
emit_stacktrace=False,
461461
to_out_var_pass=ToOutVarPass(),
462462
extract_delegate_segments=False,
463-
sym_shape_eval_pass=HintBasedSymShapeEvalPass(),
463+
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
464464
),
465465
)
466466

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,16 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
342342
quantizers = get_cadence_default_quantizers()
343343
quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
344344
super().__init__(quantizers)
345+
346+
347+
class CadenceWith16BitLinearActivationsQuantizer(CadenceQuantizer):
348+
"""
349+
Quantizer including A16 fully_connected
350+
"""
351+
352+
def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
353+
if quantizers is None:
354+
quantizers = []
355+
# Add 16-bit quantizers for LinearPattern
356+
quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
357+
super().__init__(quantizers)

0 commit comments

Comments
 (0)