Skip to content

Commit 5307368

Browse files
authored
Merge branch 'main' into stable_diffusion_lcm
2 parents 3b50602 + bf7d755 commit 5307368

File tree

159 files changed

+5879
-1119
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

159 files changed

+5879
-1119
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mpmath==1.3.0
22
numpy>=2.0.0; python_version >= '3.10'
33
PyYAML==6.0.1
4-
ruamel.yaml==0.17.32
4+
ruamel.yaml==0.18.15
55
sympy==1.12
66
timm==0.6.13
77
tomli==2.0.1

.ci/scripts/test_llama_lora.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ cmake_build_llama_runner
5555
# Constants.
5656
RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
5757
PROMPT="What happens if you eat watermelon seeds?"
58-
EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
58+
EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C and"
5959

6060
# Export LoRA PTE file.
6161
MODEL_NAME="llama_3_2_1B_lora"

.githooks/pre-commit

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
88
echo "📝 Updating PyTorch commit pin..."
99

1010
# Run the update script
11-
if python .github/scripts/update_pytorch_pin.py; then
11+
hook_output=$(python .github/scripts/update_pytorch_pin.py 2>&1)
12+
hook_status=$?
13+
echo "$hook_output"
14+
15+
if [ $hook_status -eq 0 ]; then
1216
# Check if pytorch.txt was modified
1317
if ! git diff --quiet .ci/docker/ci_commit_pins/pytorch.txt; then
1418
echo "✅ PyTorch commit pin updated successfully"
@@ -19,9 +23,14 @@ if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
1923
echo "ℹ️ PyTorch commit pin unchanged"
2024
fi
2125
else
22-
echo "❌ Failed to update PyTorch commit pin"
23-
echo "Please run: python .github/scripts/update_pytorch_pin.py"
24-
exit 1
26+
if echo "$hook_output" | grep -qi "rate limit exceeded"; then
27+
echo "⚠️ PyTorch commit pin not updated due to GitHub API rate limiting."
28+
echo " Please manually update .ci/docker/ci_commit_pins/pytorch.txt if needed."
29+
else
30+
echo "❌ Failed to update PyTorch commit pin"
31+
echo "Please run: python .github/scripts/update_pytorch_pin.py"
32+
exit 1
33+
fi
2534
fi
2635
fi
2736

.github/scripts/update_pytorch_pin.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import re
55
import sys
66
import urllib.request
7-
from datetime import datetime
87

98

109
def parse_nightly_version(nightly_version):
@@ -53,7 +52,7 @@ def get_commit_hash_for_nightly(date_str):
5352
Commit hash string
5453
"""
5554
api_url = "https://api.github.com/repos/pytorch/pytorch/commits"
56-
params = f"?sha=nightly&per_page=100"
55+
params = f"?sha=nightly&per_page=50"
5756
url = api_url + params
5857

5958
req = urllib.request.Request(url)
@@ -74,14 +73,21 @@ def get_commit_hash_for_nightly(date_str):
7473
commit_msg = commit.get("commit", {}).get("message", "")
7574
# Check if the first line of commit message matches
7675
first_line = commit_msg.split("\n")[0].strip()
77-
if first_line == target_title or first_line.startswith(f"{date_str} nightly"):
78-
return commit["sha"]
76+
if first_line.startswith(f"{date_str} nightly"):
77+
return extract_hash_from_title(first_line)
7978

8079
raise ValueError(
8180
f"Could not find commit with title matching '{target_title}' in nightly branch"
8281
)
8382

8483

84+
def extract_hash_from_title(title):
85+
match = re.search(r"\(([0-9a-fA-F]{7,40})\)", title)
86+
if not match:
87+
raise ValueError(f"Could not extract commit hash from title '{title}'")
88+
return match.group(1)
89+
90+
8591
def update_pytorch_pin(commit_hash):
8692
"""
8793
Update .ci/docker/ci_commit_pins/pytorch.txt with the new commit hash.

.github/workflows/pull.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,7 @@ jobs:
892892
# Install test requirements
893893
pip install -r backends/nxp/requirements-tests-pypi.txt
894894
pip install -r backends/nxp/requirements-tests-eiq.txt
895+
PYTHON_EXECUTABLE=python bash examples/nxp/setup.sh
895896
896897
# Run pytest
897898
PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh

backends/aoti/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2626
find_package_torch()
2727

2828
# Common AOTI functionality - combines all AOTI common components
29-
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
29+
set(_aoti_common_sources common_shims.cpp)
3030
add_library(aoti_common STATIC ${_aoti_common_sources})
3131
target_include_directories(
3232
aoti_common

backends/aoti/aoti_model_container.h renamed to backends/aoti/aoti_delegate_handle.h

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,36 +60,17 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
6060
AOTInductorStreamHandle stream_handle,
6161
AOTIProxyExecutorHandle proxy_executor_handle);
6262

63-
// Global function pointers (will be loaded dynamically)
64-
extern AOTInductorModelContainerCreateWithDeviceFunc
65-
AOTInductorModelContainerCreateWithDevice;
66-
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
67-
extern AOTInductorModelContainerGetNumInputsFunc
68-
AOTInductorModelContainerGetNumInputs;
69-
extern AOTInductorModelContainerGetNumOutputsFunc
70-
AOTInductorModelContainerGetNumOutputs;
71-
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
72-
7363
// Retrieves the name of an input tensor by index from the AOTI model container.
74-
// Needed by Metal backend
7564
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
7665
AOTInductorModelContainerHandle container_handle,
7766
size_t input_idx,
7867
const char** input_name);
7968

8069
// Retrieves the number of constants from the AOTI model container.
81-
// Needed by Metal backend
8270
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
8371
AOTInductorModelContainerHandle container_handle,
8472
size_t* num_constants);
8573

86-
// Global function pointers (will be loaded dynamically).
87-
// Needed by Metal backend
88-
extern AOTInductorModelContainerGetInputNameFunc
89-
AOTInductorModelContainerGetInputName;
90-
extern AOTInductorModelContainerGetNumConstantsFunc
91-
AOTInductorModelContainerGetNumConstants;
92-
9374
} // extern "C"
9475

9576
// AOTI Delegate Handle structure
@@ -99,6 +80,13 @@ struct AOTIDelegateHandle {
9980
AOTInductorModelContainerHandle container_handle;
10081
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
10182
// dependency
83+
84+
// Function pointers specific to this handle's shared library
85+
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
86+
AOTInductorModelContainerDeleteFunc delete_container;
87+
AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
88+
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
89+
AOTInductorModelContainerRunFunc run;
10290
};
10391

10492
} // namespace aoti

backends/aoti/aoti_model_container.cpp

Lines changed: 0 additions & 39 deletions
This file was deleted.

backends/aoti/passes/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "passes",
7+
srcs = [
8+
"replace_view_copy_with_view.py",
9+
],
10+
visibility = [
11+
"//executorch/...",
12+
],
13+
deps = [
14+
"//caffe2:torch",
15+
"//executorch/exir:pass_base",
16+
],
17+
)

backends/cuda/replace_slice_copy_with_slice.py renamed to backends/aoti/passes/replace_view_copy_with_view.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-strict
7+
# This pass replaces view_copy ops with view ops. This is different than
8+
# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py
9+
# because this should only be used in the AOTInductor backend, as it
10+
# has less restrictions on whether the tensor memory is densely packed,
811

9-
from typing import Dict, Iterable, Tuple
12+
from typing import Dict, Iterable
1013

1114
import torch
1215
from executorch.exir.dialects._ops import ops
@@ -15,33 +18,30 @@
1518
from torch import fx
1619

1720

18-
_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
19-
torch.ops.aten.slice_copy.Tensor,
20-
ops.edge.aten.slice_copy.Tensor,
21-
)
22-
23-
_SLICE_TARGETS: Dict[
21+
_VIEW_TARGETS: Dict[
2422
torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
2523
] = {
2624
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
2725
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
26+
torch.ops.aten.select_copy.int: torch.ops.aten.select.int,
27+
ops.edge.aten.select_copy.int: ops.edge.aten.select.int,
2828
}
2929

3030

31-
class ReplaceSliceCopyWithSlicePass(ExportPass):
32-
"""Replace non-mutated ``slice_copy`` results with ``slice`` views."""
31+
class ReplaceViewCopyWithViewPass(ExportPass):
32+
"""Replace non-mutated ``view_copy`` type of ops with ``view`` ops."""
3333

3434
def call(self, graph_module: fx.GraphModule) -> PassResult:
3535
graph_changed = False
3636

3737
for node in graph_module.graph.nodes:
38-
if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
38+
if node.op != "call_function" or node.target not in _VIEW_TARGETS:
3939
continue
4040

4141
if self._has_blocking_user(node, node.users.keys()):
4242
continue
4343

44-
node.target = _SLICE_TARGETS[node.target]
44+
node.target = _VIEW_TARGETS[node.target]
4545
graph_changed = True
4646

4747
if graph_changed:

0 commit comments

Comments
 (0)