
Commit b157fc1

Merge branch 'pytorch:main' into main

2 parents 59fdd72 + 914d5ff

File tree: 43 files changed (+2859, -521 lines)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-732b11313b2006b4d8649500eaf5567ec6ac1e49
+f8aa919593cc51301ade73a2ee5491582521ab80

.github/workflows/add-unanswered-to-project.yml

Lines changed: 19 additions & 10 deletions
@@ -42,16 +42,17 @@ jobs:
 "agunapal", "SamGondelman", "Ninja91", "ivayloen", "DrJessop", "rodrigos01meta", "akrieger", "cmt0", "yiming0416",
 "ethansfng", "ThomasJannaud", "nirvanagth", "marcinkwiatkowski", "3l1", "omerjerk", "nitish2112", "yipjustin",
 "ejnguyen", "andrewor14", "phaiting", "mgiordy", "LeeOHzzZ", "adicatana", "Polyomino", "ezrilow", "navsud",
-"michaelmaitland", "RahulC7", "seyeong-han", "thdusdl1219", "jaejunku", "felixweilbach", "apullin", "trviv", "YifanShenSZ",
-"RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat", "azad-meta", "junpi", "pytorchbot", "pytorchmergebot", "pytorchupdatebot",
-"facebook-github-bot", "app/dependabot", "Erik-Lundell", "zingo", "AdrianLundell", "oscarandersson8218", "per", "Sebastian-Larsson",
-"SaoirseARM", "robell", "mansnils", "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov",
-"fumchin", "wwwind", "benkli01", "Tessil", "maddun01", "Michiel-Olieslagers", "armwaheed", "agrima1304", "emmakujala",
-"annietllnd", "MatthiasHertel80", "AlexTawseArm", "jmahbs", "morgolock", "Christoffer-JL", "ArmRyan", "xingguo01",
-"tgonzalezorlandoarm", "chizkiyahu", "sarah-blades", "haowhsu-quic", "shewu-quic", "winskuo-quic", "chunit-quic", "DannyYuyang-quic",
-"chuntl", "thchenqti", "jethroqti", "chenweng-quic", "cymbalrush", "DenisVieriu97", "billmguo", "StrycekSimon", "jirioc", "robert-kalmar",
-"skywall", "MartinPavella", "roman-janik-nxp", "novak-vaclav", "neuropilot-captain", "dijopaul", "cad-rlc", "cad-audio",
-"ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", "cavusmustafa", "anzr299", "Jiseong-oh", "alexdean08",
+"michaelmaitland", "RahulC7", "seyeong-han", "thdusdl1219", "jaejunku", "felixweilbach", "apullin", "trviv", "junluan01",
+"YifanShenSZ", "RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat", "azad-meta", "junpi", "pytorchbot", "pytorchmergebot",
+"pytorchupdatebot", "facebook-github-bot", "app/dependabot", "Erik-Lundell", "zingo", "AdrianLundell", "oscarandersson8218",
+"per", "Sebastian-Larsson", "SaoirseARM", "robell", "mansnils", "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm",
+"perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind", "benkli01", "Tessil", "maddun01", "Michiel-Olieslagers", "armwaheed",
+"agrima1304", "emmakujala", "annietllnd", "MatthiasHertel80", "AlexTawseArm", "jmahbs", "morgolock", "Christoffer-JL",
+"ArmRyan", "xingguo01", "tgonzalezorlandoarm", "chizkiyahu", "sarah-blades", "haowhsu-quic", "shewu-quic", "winskuo-quic",
+"chunit-quic", "DannyYuyang-quic", "chuntl", "thchenqti", "jethroqti", "chenweng-quic", "cymbalrush", "DenisVieriu97",
+"billmguo", "StrycekSimon", "jirioc", "robert-kalmar", "skywall", "MartinPavella", "roman-janik-nxp", "novak-vaclav",
+"neuropilot-captain", "dijopaul", "cad-rlc", "cad-audio", "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", "cavusmustafa",
+"anzr299", "Jiseong-oh", "alexdean08",
 // explicitly include the dependabot bot login seen in PRs
 "dependabot[bot]"
 ]);
@@ -61,6 +62,9 @@ jobs:
 "meta", "facebook", "pytorch", "arm", "apple", "qualcomm", "nxp", "mediatek", "cadence", "intel", "samsung"
 ]);
 
+// Labels on PRs to exclude from being added to the project
+const excludedPrLabels = new Set(["fb-exported", "meta-exported"]);
+
 // Simple cache for user -> boolean (member of excluded org)
 const orgsCache = new Map();
 
@@ -93,6 +97,11 @@ jobs:
 return false;
 }
 
+function hasExcludedLabel(item) {
+  if (!item || !item.labels) return false;
+  return item.labels.some(l => l && l.name && excludedPrLabels.has(l.name.toLowerCase()));
+}
+
 async function addItem(contentId, type, number) {
 try {
 await github.graphql(`

.github/workflows/doc-build.yml

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,10 @@ on:
   schedule:
     - cron: '0 0 * * *'
 
+concurrency:
+  group: docs-${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
 jobs:
   build:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/aoti/slim/c10/cuda/Exception.h

Lines changed: 49 additions & 8 deletions
@@ -8,26 +8,55 @@
 
 #pragma once
 
-#ifdef CUDA_AVAILABLE
-
 #include <cuda.h>
 #include <cuda_runtime.h>
 
 #include <executorch/backends/aoti/slim/c10/macros/Macros.h>
+#include <executorch/runtime/core/error.h>
 #include <executorch/runtime/platform/assert.h>
 #include <executorch/runtime/platform/log.h>
 
 /// Checks a CUDA expression and aborts on error.
 /// @param EXPR The CUDA expression to check.
-#define ET_CUDA_CHECK(EXPR)                                                 \
-  do {                                                                      \
-    const cudaError_t __err = EXPR;                                         \
-    ET_CHECK_MSG(                                                           \
-        __err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \
+#ifndef ET_CUDA_CHECK
+#define ET_CUDA_CHECK(EXPR)                                           \
+  do {                                                                \
+    const cudaError_t __err = EXPR;                                   \
+    if (__err == cudaSuccess) {                                       \
+      break;                                                          \
+    }                                                                 \
+    ET_LOG(                                                           \
+        Error,                                                        \
+        "%s:%d CUDA error: %s",                                       \
+        __FILE__,                                                     \
+        __LINE__,                                                     \
+        cudaGetErrorString(__err));                                   \
+    ET_CHECK_MSG(false, "CUDA error: %s", cudaGetErrorString(__err)); \
   } while (0)
+#endif
+
+/// Checks a CUDA expression and returns Error::Internal on failure.
+/// @param EXPR The CUDA expression to check.
+#ifndef ET_CUDA_CHECK_OR_RETURN_ERROR
+#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR)                           \
+  do {                                                                \
+    const cudaError_t __err = EXPR;                                   \
+    if (__err == cudaSuccess) {                                       \
+      break;                                                          \
+    }                                                                 \
+    ET_LOG(                                                           \
+        Error,                                                        \
+        "%s:%d CUDA error: %s",                                       \
+        __FILE__,                                                     \
+        __LINE__,                                                     \
+        cudaGetErrorString(__err));                                   \
+    return ::executorch::runtime::Error::Internal;                    \
+  } while (0)
+#endif
 
 /// Checks a CUDA expression and logs a warning on error (non-fatal).
 /// @param EXPR The CUDA expression to check.
+#ifndef ET_CUDA_LOG_WARN
 #define ET_CUDA_LOG_WARN(EXPR)                                        \
   do {                                                                \
     const cudaError_t __err = EXPR;                                   \
@@ -36,5 +65,17 @@
       ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err));   \
     }                                                                 \
   } while (0)
+#endif
+
+/// Kernel launch check macro (with return) - checks cudaGetLastError after
+/// kernel launch.
+#ifndef ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR
+#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
+  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
+#endif
 
-#endif // CUDA_AVAILABLE
+/// Kernel launch check macro (without return) - checks cudaGetLastError after
+/// kernel launch.
+#ifndef ET_CUDA_KERNEL_LAUNCH_CHECK
+#define ET_CUDA_KERNEL_LAUNCH_CHECK() ET_CUDA_CHECK(cudaGetLastError())
+#endif
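
For reference, a minimal sketch of how the new error-returning macro is meant to be used from a function that itself returns an Error; copy_to_device below is a hypothetical caller, not part of this commit:

// Hypothetical caller, for illustration only; not part of this commit.
#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>

using executorch::runtime::Error;

Error copy_to_device(void* dst, const void* src, size_t nbytes) {
  // On failure the macro logs "<file>:<line> CUDA error: <message>" and
  // returns Error::Internal from copy_to_device, instead of aborting the
  // process the way ET_CUDA_CHECK does.
  ET_CUDA_CHECK_OR_RETURN_ERROR(
      cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToDevice));
  return Error::Ok;
}

ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() applies the same pattern to cudaGetLastError() immediately after a kernel launch, which is where launch-configuration errors surface.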

backends/aoti/slim/core/Storage.h

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 
 #ifdef CUDA_AVAILABLE
 #include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
-#include <executorch/backends/cuda/runtime/guard.h>
+#include <executorch/backends/aoti/slim/cuda/guard.h>
 #endif
 
 #include <executorch/backends/aoti/slim/c10/core/Device.h>

backends/aoti/slim/core/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def define_common_targets():
         "//executorch/backends/aoti/slim/util:size_util",
         "//executorch/runtime/platform:platform",
         "//executorch/backends/aoti/slim/c10/cuda:exception",
-        "//executorch/backends/cuda/runtime:guard",
+        "//executorch/backends/aoti/slim/cuda:guard",
     ],
 )
backends/aoti/slim/cuda/TARGETS

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()

backends/aoti/slim/cuda/guard.cpp

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/aoti/slim/cuda/guard.h>
+#include <executorch/runtime/platform/log.h>
+#include <limits>
+#include <unordered_map>
+
+namespace executorch::backends::cuda {
+
+namespace {
+// Thread-local stream storage (private to this file)
+thread_local std::unordered_map<DeviceIndex, cudaStream_t> current_streams_;
+} // namespace
+
+Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) {
+  if (device_index == -1) {
+    // Get current device if not specified
+    // CUDA API returns int, explicit cast to DeviceIndex (int8_t) following
+    // ATen
+    int tmp_device = -1;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&tmp_device));
+    device_index = static_cast<DeviceIndex>(tmp_device);
+  }
+
+  current_streams_[device_index] = stream;
+  return Error::Ok;
+}
+
+Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index) {
+  if (device_index == -1) {
+    // CUDA API returns int, explicit cast to DeviceIndex (int8_t) following
+    // ATen
+    int tmp_device = -1;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&tmp_device));
+    device_index = static_cast<DeviceIndex>(tmp_device);
+  }
+
+  auto it = current_streams_.find(device_index);
+  if (it != current_streams_.end()) {
+    return it->second;
+  }
+
+  cudaStream_t stream;
+  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream));
+  setCurrentCUDAStream(stream, device_index);
+  return stream;
+}
+
+CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept
+    : original_device_index_(other.original_device_index_),
+      current_device_index_(other.current_device_index_) {
+  // Mark the moved-from object as "already restored" so its destructor doesn't
+  // try to restore the device
+  other.original_device_index_ = other.current_device_index_;
+}
+
+CUDAGuard::~CUDAGuard() {
+  if (original_device_index_ != current_device_index_) {
+    // DeviceIndex (int8_t) implicitly widens to int for cudaSetDevice
+    cudaError_t err = cudaSetDevice(original_device_index_);
+    if (err != cudaSuccess) {
+      ET_LOG(
+          Error,
+          "~CUDAGuard: Failed to restore device to %d: %s",
+          static_cast<int>(original_device_index_),
+          cudaGetErrorString(err));
+    }
+  }
+}
+
+Error CUDAGuard::set_index(DeviceIndex device_index) {
+  // CUDA API returns int, explicit cast to DeviceIndex (int8_t) following ATen
+  int tmp_device = -1;
+  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&tmp_device));
+
+  original_device_index_ = static_cast<DeviceIndex>(tmp_device);
+  current_device_index_ = device_index;
+
+  if (current_device_index_ != original_device_index_) {
+    // DeviceIndex (int8_t) implicitly widens to int for cudaSetDevice
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_));
+  }
+
+  return Error::Ok;
+}
+
+Result<CUDAGuard> CUDAGuard::create(DeviceIndex device_index) {
+  CUDAGuard guard; // Fixed: Removed () to create a variable, not a function
+  ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index));
+  return guard;
+}
+
+CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept
+    : device_guard_(std::move(other.device_guard_)),
+      original_stream_(other.original_stream_),
+      current_stream_(other.current_stream_),
+      device_index_(other.device_index_) {
+  // Mark the moved-from object as "already restored" so its destructor doesn't
+  // try to restore the stream
+  other.original_stream_ = other.current_stream_;
+}
+
+CUDAStreamGuard::~CUDAStreamGuard() {
+  // Restore the original stream unless this object was moved-from.
+  // After a move, original_stream_ == current_stream_, which indicates
+  // the moved-from object should not restore.
+  // Note: nullptr is a valid stream value (represents the default stream),
+  // so we must restore even if original_stream_ is nullptr.
+  if (original_stream_ != current_stream_) {
+    Error err = setCurrentCUDAStream(original_stream_, device_index_);
+    if (err != Error::Ok) {
+      ET_LOG(
+          Error,
+          "~CUDAStreamGuard: Failed to restore stream for device %d",
+          static_cast<int>(device_index_));
+    }
+  }
+}
+
+Error CUDAStreamGuard::set_stream(
+    cudaStream_t stream,
+    DeviceIndex device_index) {
+  auto result = getCurrentCUDAStream(device_index);
+  if (!result.ok()) {
+    ET_LOG(
+        Error,
+        "Failed to get current stream for device %d",
+        static_cast<int>(device_index));
+    return result.error();
+  }
+
+  original_stream_ = result.get();
+  current_stream_ = stream;
+  device_index_ = device_index;
+
+  ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index));
+
+  return Error::Ok;
+}
+
+Result<CUDAStreamGuard> CUDAStreamGuard::create(
+    cudaStream_t stream,
+    DeviceIndex device_index) {
+  auto guard_result = CUDAGuard::create(device_index);
+  ET_CHECK_OK_OR_RETURN_ERROR(guard_result.error());
+
+  CUDAStreamGuard stream_guard(std::move(guard_result.get()));
+  ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.set_stream(stream, device_index));
+
+  return stream_guard;
+}
+
+} // namespace executorch::backends::cuda
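
And a minimal sketch of the RAII usage these guards enable, assuming the names (CUDAStreamGuard, DeviceIndex, Error) are available via guard.h as used above; run_with_stream is a hypothetical caller, not part of this commit:

// Hypothetical caller, for illustration only; not part of this commit.
#include <executorch/backends/aoti/slim/cuda/guard.h>

namespace executorch::backends::cuda {

Error run_with_stream(cudaStream_t stream, DeviceIndex device_index) {
  // create() switches to `device_index` and makes `stream` current for it;
  // the destructors restore both when `guard` goes out of scope, including
  // on every early-return path.
  auto guard = CUDAStreamGuard::create(stream, device_index);
  ET_CHECK_OK_OR_RETURN_ERROR(guard.error());
  // ... enqueue kernels and copies on `stream` here ...
  return Error::Ok;
}

} // namespace executorch::backends::cuda

The move constructors matter here because create() returns the guard by value through a Result; marking the moved-from object as "already restored" keeps its destructor from undoing the device and stream state the new owner still relies on.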
