Skip to content

Commit 70015e0

Browse files
MediaPipe Teamcopybara-github
authored andcommitted
Avoids the sharing of GL contexts between nested mediapipe graphs.
PiperOrigin-RevId: 730837827
1 parent 962d7ff commit 70015e0

File tree

5 files changed

+50
-18
lines changed

5 files changed

+50
-18
lines changed

mediapipe/framework/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ cc_library(
376376
"//mediapipe/framework/tool:validate",
377377
"//mediapipe/framework/tool:validate_name",
378378
"//mediapipe/gpu:gpu_service",
379+
"//mediapipe/gpu:gpu_shared_data_internal",
379380
"//mediapipe/gpu:graph_support",
380381
"//mediapipe/util:cpu_util",
381382
"@com_google_absl//absl/base:core_headers",

mediapipe/framework/calculator_graph.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
#include "mediapipe/framework/validated_graph_config.h"
8181
#include "mediapipe/framework/vlog_overrides.h"
8282
#include "mediapipe/gpu/gpu_service.h"
83+
#include "mediapipe/gpu/gpu_shared_data_internal.h"
8384
#include "mediapipe/gpu/graph_support.h"
8485
#include "mediapipe/util/cpu_util.h"
8586

@@ -139,14 +140,34 @@ CalculatorGraph::CalculatorGraph() : CalculatorGraph(/*cc=*/nullptr) {}
139140
// Adopt all services from the CalculatorContext / parent graph.
140141
CalculatorGraph::CalculatorGraph(CalculatorContext* cc)
141142
: counter_factory_(std::make_unique<BasicCounterFactory>()),
142-
service_manager_(cc != nullptr ? cc->GetGraphServiceManager() : nullptr),
143143
profiler_(std::make_shared<ProfilingContext>()),
144144
scheduler_(this) {
145145
if (cc != nullptr) {
146146
// Nested graphs should not create default initialized services to avoid
147147
// collisions between newly created and inherited graphs.
148148
// TODO b/368015341- Use factory method to avoid CHECK in constructor.
149149
ABSL_CHECK_OK(DisallowServiceDefaultInitialization());
150+
151+
// Adopt all services from the parent graph except for GpuResources.
152+
const auto parent_service_manager = cc->GetGraphServiceManager();
153+
const auto parent_service_packets =
154+
parent_service_manager->ServicePackets();
155+
GraphServiceManager::ServiceMap service_packets;
156+
for (const auto& [key, packet] : parent_service_packets) {
157+
if (key == kGpuService.key) {
158+
// To avoid deadlocks when sharing the same GPU thread between
159+
// multiple graphs, we create a new GpuResources instance for each
160+
// sub-graph with a dedicated GL context / thread.
161+
auto resources = mediapipe::GpuResources::Create(
162+
*packet.Get<std::shared_ptr<mediapipe::GpuResources>>());
163+
ABSL_CHECK_OK(resources);
164+
service_packets[key] =
165+
MakePacket<std::shared_ptr<mediapipe::GpuResources>>(*resources);
166+
} else {
167+
service_packets[key] = packet;
168+
}
169+
}
170+
service_manager_.SetServicePackets(service_packets);
150171
}
151172
SetVLogOverrides();
152173
}

mediapipe/framework/graph_service_manager.h

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <map>
55
#include <memory>
66
#include <string>
7+
#include <utility>
78

89
#include "absl/log/absl_check.h"
910
#include "absl/status/status.h"
@@ -15,21 +16,10 @@ namespace mediapipe {
1516

1617
class GraphServiceManager {
1718
public:
18-
GraphServiceManager() = default;
19-
20-
explicit GraphServiceManager(
21-
const GraphServiceManager* external_graph_manager) {
22-
if (external_graph_manager != nullptr) {
23-
// Nested graphs inherit all graph services from their parent graph and
24-
// disable the registration of new services in the nested graph. This
25-
// ensures that all services are created during the initialization of
26-
// parent graph.
27-
service_packets_ = external_graph_manager->ServicePackets();
28-
}
29-
}
30-
3119
using ServiceMap = std::map<std::string, Packet>;
3220

21+
GraphServiceManager() = default;
22+
3323
template <typename T>
3424
absl::Status SetServiceObject(const GraphService<T>& service,
3525
std::shared_ptr<T> object) {
@@ -41,6 +31,10 @@ class GraphServiceManager {
4131

4232
absl::Status SetServicePacket(const GraphServiceBase& service, Packet p);
4333

34+
void SetServicePackets(const ServiceMap& service_packets) {
35+
service_packets_ = service_packets;
36+
}
37+
4438
template <typename T>
4539
std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) const {
4640
Packet p = GetServicePacket(service);

mediapipe/gpu/gpu_shared_data_internal.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ GpuResources::StatusOrGpuResources GpuResources::Create(
108108
return gpu_resources;
109109
}
110110

111+
GpuResources::StatusOrGpuResources GpuResources::Create(
112+
const GpuResources& gpu_resources,
113+
const MultiPoolOptions* gpu_buffer_pool_options) {
114+
return Create(gpu_resources.gl_context()->native_context(),
115+
gpu_buffer_pool_options);
116+
}
117+
111118
GpuResources::GpuResources(std::shared_ptr<GlContext> gl_context,
112119
const MultiPoolOptions* gpu_buffer_pool_options)
113120
: gl_key_context_(new GlContextMapType(),
@@ -241,9 +248,10 @@ absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
241248
// TODO: expose and use an actual ID instead of using the
242249
// canonicalized name.
243250
const std::shared_ptr<GlContext>& GpuResources::gl_context(
244-
CalculatorContext* cc) {
251+
CalculatorContext* cc) const {
245252
if (cc) {
246-
auto it = gl_key_context_->find(node_key_[cc->NodeName()]);
253+
const auto node_key_it = node_key_.find(cc->NodeName());
254+
const auto it = gl_key_context_->find(node_key_it->second);
247255
if (it != gl_key_context_->end()) {
248256
return it->second;
249257
}

mediapipe/gpu/gpu_shared_data_internal.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class GpuResources {
5757
PlatformGlContext external_context,
5858
const MultiPoolOptions* gpu_buffer_pool_options = nullptr);
5959

60+
// Creates a GpuResources instance that is shared with the GL context provided
61+
// by the gpu_resources argument.
62+
static StatusOrGpuResources Create(
63+
const GpuResources& gpu_resources,
64+
const MultiPoolOptions* gpu_buffer_pool_options = nullptr);
65+
6066
// The destructor must be defined in the implementation file so that on iOS
6167
// the correct ARC release calls are generated.
6268
~GpuResources();
@@ -65,9 +71,11 @@ class GpuResources {
6571

6672
// Shared GL context for calculators.
6773
// TODO: require passing a context or node identifier.
68-
const std::shared_ptr<GlContext>& gl_context() { return gl_context(nullptr); }
74+
const std::shared_ptr<GlContext>& gl_context() const {
75+
return gl_context(nullptr);
76+
}
6977

70-
const std::shared_ptr<GlContext>& gl_context(CalculatorContext* cc);
78+
const std::shared_ptr<GlContext>& gl_context(CalculatorContext* cc) const;
7179

7280
// Shared buffer pool.
7381
GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; }

0 commit comments

Comments
 (0)