Skip to content

Commit 5514b5f

Browse files
committed
Fixed COM driver revokes
1 parent 3cd94c1 commit 5514b5f

File tree

11 files changed

+377
-218
lines changed

11 files changed

+377
-218
lines changed

driver_00Amethyst/DriverService.cpp

Lines changed: 37 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,14 @@
33
#include <ranges>
44
#include <RpcProxy.h>
55
#include <shellapi.h>
6-
#include <wil/cppwinrt_helpers.h>
76

87
#include "constants.hpp"
98
#include "Logging.h"
109
#include "util/color.hpp"
1110

12-
// Wide String to UTF8 String
13-
inline std::string WStringToString(const std::wstring& w_str)
14-
{
15-
const int count = WideCharToMultiByte(CP_UTF8, 0, w_str.c_str(), w_str.length(), nullptr, 0, nullptr, nullptr);
16-
std::string str(count, 0);
17-
WideCharToMultiByte(CP_UTF8, 0, w_str.c_str(), -1, str.data(), count, nullptr, nullptr);
18-
return str;
19-
}
20-
2111
DWORD DriverService::proxy_stub_registration_cookie_ = 0;
2212

23-
DriverService::DriverService() : register_cookie_(0)
24-
{
25-
}
13+
DriverService::DriverService() = default;
2614

2715
HRESULT DriverService::GetVersion(DWORD* apiVersion) noexcept
2816
{
@@ -36,42 +24,46 @@ HRESULT DriverService::GetVersion(DWORD* apiVersion) noexcept
3624

3725
HRESULT DriverService::SetTrackerState(dTrackerBase tracker)
3826
{
27+
if (tracker_vector_ == nullptr) return E_FAIL;
28+
3929
// HMD pose override
4030
if (tracker.Role == TrackerHead)
4131
return EnableOverride(0, tracker.ConnectionState);
4232

4333
// Normal case
44-
if (tracker_vector_.contains(static_cast<ITrackerType>(tracker.Role)))
34+
if (tracker_vector_->contains(static_cast<ITrackerType>(tracker.Role)))
4535
{
4636
// Create a handle to the updated (native) tracker
47-
const auto p_tracker = &tracker_vector_[static_cast<ITrackerType>(tracker.Role)];
37+
const auto p_tracker = &tracker_vector_->at(static_cast<ITrackerType>(tracker.Role));
4838

4939
// Check the state and attempts spawning the tracker
5040
if (!p_tracker->is_added() && !p_tracker->spawn())
5141
{
52-
logMessage(std::format("Couldn't spawn tracker with ID {} due to an unknown native exception.",
42+
logMessage(std::format("Couldn't spawn tracker ID {} due to an unknown native exception.",
5343
static_cast<int>(tracker.Role)));
5444
return E_FAIL; // Failure
5545
}
5646

5747
// Set the state of the native tracker
5848
p_tracker->set_state(tracker.ConnectionState);
59-
logMessage(std::format("Unmanaged (native) tracker with ID {} state has been set to {}.",
49+
logMessage(std::format("Tracker ID {} state set to {}.",
6050
static_cast<int>(tracker.Role), tracker.ConnectionState == 1));
6151

6252
// Call the VR update handler and compose the result
63-
tracker_vector_[static_cast<ITrackerType>(tracker.Role)].update();
53+
tracker_vector_->at(static_cast<ITrackerType>(tracker.Role)).update();
6454
return S_OK;
6555
}
6656

67-
logMessage(std::format("Couldn't spawn tracker with ID {}. The tracker index was out of bounds.",
57+
logMessage(std::format("Couldn't spawn tracker ID {}. The tracker index was out of bounds.",
6858
static_cast<int>(tracker.Role)));
6959

7060
return ERROR_INVALID_INDEX; // Failure
7161
}
7262

7363
HRESULT DriverService::UpdateTracker(dTrackerBase tracker)
7464
{
65+
if (tracker_vector_ == nullptr) return E_FAIL;
66+
7567
// HMD pose override
7668
if (tracker.Role == TrackerHead)
7769
{
@@ -84,12 +76,12 @@ HRESULT DriverService::UpdateTracker(dTrackerBase tracker)
8476
}
8577

8678
// Normal case
87-
if (tracker_vector_.contains(static_cast<ITrackerType>(tracker.Role)))
79+
if (tracker_vector_->contains(static_cast<ITrackerType>(tracker.Role)))
8880
{
8981
// Update the pose of the passed tracker
90-
if (!tracker_vector_[static_cast<ITrackerType>(tracker.Role)].set_pose(tracker))
82+
if (!tracker_vector_->at(static_cast<ITrackerType>(tracker.Role)).set_pose(tracker))
9183
{
92-
logMessage(std::format("Couldn't spawn tracker with ID {} due to an unknown native exception.",
84+
logMessage(std::format("Couldn't spawn tracker ID {} due to an unknown native exception.",
9385
static_cast<int>(tracker.Role)));
9486
return E_FAIL; // Failure
9587
}
@@ -98,7 +90,7 @@ HRESULT DriverService::UpdateTracker(dTrackerBase tracker)
9890
return S_OK;
9991
}
10092

101-
logMessage(std::format("Couldn't spawn tracker with ID {}. The tracker index was out of bounds.",
93+
logMessage(std::format("Couldn't spawn tracker ID {}. The tracker index was out of bounds.",
10294
static_cast<int>(tracker.Role)));
10395

10496
return ERROR_INVALID_INDEX; // Failure
@@ -171,9 +163,9 @@ HRESULT DriverService::UpdateInputBoolean(dTrackerType tracker, wchar_t* path, b
171163
return ERROR_EMPTY; // Compose the reply
172164
}
173165

174-
if (tracker_vector_.contains(static_cast<ITrackerType>(tracker)))
175-
return tracker_vector_[static_cast<ITrackerType>(tracker)]
176-
.update_input(WStringToString(path), static_cast<bool>(value))
166+
if (tracker_vector_->contains(static_cast<ITrackerType>(tracker)))
167+
return tracker_vector_->at(static_cast<ITrackerType>(tracker))
168+
.update_input(WStringToString(path), static_cast<bool>(value))
177169
? S_OK
178170
: ERROR_INVALID_ACCESS;
179171

@@ -188,9 +180,9 @@ HRESULT DriverService::UpdateInputScalar(dTrackerType tracker, wchar_t* path, fl
188180
return ERROR_EMPTY; // Compose the reply
189181
}
190182

191-
if (tracker_vector_.contains(static_cast<ITrackerType>(tracker)))
192-
return tracker_vector_[static_cast<ITrackerType>(tracker)]
193-
.update_input(WStringToString(path), value)
183+
if (tracker_vector_->contains(static_cast<ITrackerType>(tracker)))
184+
return tracker_vector_->at(static_cast<ITrackerType>(tracker))
185+
.update_input(WStringToString(path), value)
194186
? S_OK
195187
: ERROR_INVALID_ACCESS;
196188

@@ -228,45 +220,28 @@ void DriverService::UninstallProxyStub()
228220
}
229221
}
230222

231-
std::optional<winrt::hresult_error> DriverService::SetupService(const _GUID clsid)
223+
void DriverService::TrackerVector(std::map<ITrackerType, BodyTracker>* const& vector)
232224
{
233-
try
234-
{
235-
InstallProxyStub();
236-
237-
winrt::com_ptr<IUnknown> service;
238-
winrt::check_hresult(RegisterActiveObject(
239-
static_cast<IDriverService*>(this),
240-
clsid, ACTIVEOBJECT_STRONG,
241-
&register_cookie_));
242-
243-
winrt::check_hresult(GetActiveObject(
244-
clsid, nullptr, service.put()));
245-
}
246-
catch (const winrt::hresult_error& e)
247-
{
248-
return e;
249-
}
250-
catch (...)
251-
{
252-
return winrt::hresult_error(-1);
253-
}
254-
255-
return std::nullopt;
225+
tracker_vector_ = vector;
256226
}
257227

258-
void DriverService::UpdateTrackers()
228+
void DriverService::RebuildCallback(IRebuildCallback* callback)
259229
{
260-
for (auto& tracker : tracker_vector_ | std::views::values)
261-
tracker.update(); // Update all
230+
rebuild_callback_ = callback;
262231
}
263232

264-
void DriverService::AddTracker(const std::string& serial, const ITrackerType role)
233+
ULONG DriverService::Release() noexcept
265234
{
266-
tracker_vector_[role] = BodyTracker(serial, role);
267-
}
235+
const auto count = implements::Release();
236+
logMessage(std::format("COM ref released, running total: {}", count));
268237

269-
std::map<ITrackerType, BodyTracker> DriverService::TrackerVector()
270-
{
271-
return tracker_vector_;
238+
if (count == 1 && rebuild_callback_)
239+
{
240+
logMessage("Client disconnected");
241+
logMessage("COM revocation detected, requesting rebuild!");
242+
rebuild_callback_->OnRebuildRequested();
243+
rebuild_callback_ = nullptr; // Clear the callback
244+
}
245+
246+
return count;
272247
}

driver_00Amethyst/DriverService.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,24 @@
1414
#include "driver_Amethyst.h"
1515
#include "wilx.hpp"
1616
#include <functional>
17+
#include "Logging.h"
18+
19+
namespace winrt
20+
{
21+
struct hresult_error;
22+
}
1723

1824
extern "C" {
1925
_Check_return_ HRESULT STDAPICALLTYPE DLLGETCLASSOBJECT_ENTRY(
2026
_In_ REFCLSID rclsid, _In_ REFIID riid, _Outptr_ void** ppv);
2127
}
2228

29+
struct IRebuildCallback
30+
{
31+
virtual void OnRebuildRequested() = 0;
32+
virtual ~IRebuildCallback() = default;
33+
};
34+
2335
class DriverService : public winrt::implements<
2436
DriverService, IDriverService, IVersionedApi, winrt::non_agile>
2537
{
@@ -52,18 +64,17 @@ class DriverService : public winrt::implements<
5264
static void InstallProxyStub();
5365
static void UninstallProxyStub();
5466

55-
std::optional<winrt::hresult_error> SetupService(_GUID clsid);
67+
void TrackerVector(std::map<ITrackerType, BodyTracker>* const& vector);
68+
void RebuildCallback(IRebuildCallback* callback);
5669

57-
void UpdateTrackers();
58-
void AddTracker(const std::string& serial, const ITrackerType role);
59-
std::map<ITrackerType, BodyTracker> TrackerVector();
70+
ULONG __stdcall Release() noexcept override;
6071

6172
void RegisterDriverPoseHandler(const std::function<HRESULT(const uint32_t& id, dDriverPose pose)>& handler);
6273
void RegisterOverrideSetHandler(const std::function<HRESULT(const uint32_t& id, bool isEnabled)>& handler);
6374

6475
private:
65-
DWORD register_cookie_;
66-
std::map<ITrackerType, BodyTracker> tracker_vector_;
76+
IRebuildCallback* rebuild_callback_ = nullptr;
77+
std::map<ITrackerType, BodyTracker>* tracker_vector_;
6778

6879
std::function<HRESULT(const uint32_t& id, dDriverPose pose)> pose_update_handler_;
6980
std::function<HRESULT(const uint32_t& id, bool isEnabled)> override_set_handler_;

driver_00Amethyst/Logging.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,13 @@ inline void logMessageVerbose(const std::string& message, ...)
4646
va_end(args);
4747

4848
OutputDebugStringA(buffer);
49-
}
49+
}
50+
51+
// Wide String to UTF8 String
52+
inline std::string WStringToString(const std::wstring& w_str)
53+
{
54+
const int count = WideCharToMultiByte(CP_UTF8, 0, w_str.c_str(), w_str.length(), nullptr, 0, nullptr, nullptr);
55+
std::string str(count, 0);
56+
WideCharToMultiByte(CP_UTF8, 0, w_str.c_str(), -1, str.data(), count, nullptr, nullptr);
57+
return str;
58+
}

0 commit comments

Comments
 (0)