Skip to content

Commit a501448

Browse files
committed
Add get_iids_tearoff() hook.
Add a `get_iids_tearoff()` hook to complement the `query_interface_tearoff()` hook. `query_interface_tearoff()` allows dynamically extending types with additional interfaces that can be accessed with the `QueryInterface()` method. However, these interfaces were not discoverable by the `GetIids()` method (and therefore not by the `winrt::get_interfaces()` function either). The `get_iids_tearoff()` hook allows adding the GUIDs of interfaces to the array returned by `GetIids()`. This also fixes a bug in the implementation of the `root_implements_type::is_composing` branch of `NonDelegatingGetIids()` where `*array` was updated causing the local iids to be unreachable by the caller and risking the caller reading past the end of the array.
1 parent fd0e959 commit a501448

File tree

3 files changed

+152
-5
lines changed

3 files changed

+152
-5
lines changed

strings/base_implements.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,11 @@ namespace winrt::impl
912912
return error_no_interface;
913913
}
914914

915+
virtual array_view<guid> get_iids_tearoff() const noexcept
916+
{
917+
return {};
918+
}
919+
915920
root_implements() noexcept
916921
{
917922
}
@@ -1021,7 +1026,8 @@ namespace winrt::impl
10211026
int32_t __stdcall NonDelegatingGetIids(uint32_t* count, guid** array) noexcept
10221027
{
10231028
auto const& local_iids = static_cast<D*>(this)->get_local_iids();
1024-
uint32_t const& local_count = local_iids.first;
1029+
auto tearoff_iids = get_iids_tearoff();
1030+
uint32_t const& local_count = local_iids.first + tearoff_iids.size();
10251031
if constexpr (root_implements_type::is_composing)
10261032
{
10271033
if (local_count > 0)
@@ -1033,8 +1039,10 @@ namespace winrt::impl
10331039
{
10341040
return error_bad_alloc;
10351041
}
1036-
*array = std::copy(local_iids.second, local_iids.second + local_count, *array);
1037-
std::copy(inner_iids.cbegin(), inner_iids.cend(), *array);
1042+
auto _array = *array;
1043+
_array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array);
1044+
_array = std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array);
1045+
std::copy(inner_iids.cbegin(), inner_iids.cend(), _array);
10381046
}
10391047
else
10401048
{
@@ -1051,7 +1059,9 @@ namespace winrt::impl
10511059
{
10521060
return error_bad_alloc;
10531061
}
1054-
std::copy(local_iids.second, local_iids.second + local_count, *array);
1062+
auto _array = *array;
1063+
_array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array);
1064+
std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array);
10551065
}
10561066
else
10571067
{

test/old_tests/UnitTests/Composable.cpp

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#include "pch.h"
22
#include "catch.hpp"
33
#include "winrt/Composable.h"
4+
#include <windows.foundation.h>
45

56
using namespace winrt;
67
using namespace Windows::Foundation;
78
using namespace Composable;
89

910
using namespace std::string_view_literals;
11+
using hstring = ::winrt::hstring;
1012

1113
namespace
1214
{
@@ -167,4 +169,120 @@ TEST_CASE("Composable conversions")
167169
{
168170
TestCalls(*make_self<Foo>());
169171
TestCalls(*make_self<Bar>());
170-
}
172+
}
173+
174+
namespace
175+
{
176+
// Creates an implementation of IStringable as a tearoff.
177+
HRESULT make_stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value, void** result) noexcept
178+
{
179+
struct stringable final : ABI::Windows::Foundation::IStringable
180+
{
181+
stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value) :
182+
m_object(object.as<::IInspectable>()),
183+
m_value(value)
184+
{
185+
}
186+
187+
HRESULT __stdcall ToString(HSTRING* result) noexcept final
188+
{
189+
return WindowsDuplicateString(static_cast<HSTRING>(get_abi(m_value)), result);
190+
}
191+
192+
HRESULT __stdcall QueryInterface(GUID const& id, void** result) noexcept final
193+
{
194+
if (is_guid_of<IStringable>(id))
195+
{
196+
*result = static_cast<ABI::Windows::Foundation::IStringable*>(this);
197+
AddRef();
198+
return S_OK;
199+
}
200+
201+
return m_object->QueryInterface(id, result);
202+
}
203+
204+
ULONG __stdcall AddRef() noexcept final
205+
{
206+
return 1 + m_references.fetch_add(1, std::memory_order_relaxed);
207+
}
208+
209+
ULONG __stdcall Release() noexcept final
210+
{
211+
uint32_t const remaining = m_references.fetch_sub(1, std::memory_order_relaxed) - 1;
212+
213+
if (remaining == 0)
214+
{
215+
delete this;
216+
}
217+
218+
return remaining;
219+
}
220+
221+
HRESULT __stdcall GetIids(ULONG* count, GUID** iids) noexcept final
222+
{
223+
return m_object->GetIids(count, iids);
224+
}
225+
226+
HRESULT __stdcall GetRuntimeClassName(HSTRING* result) noexcept final
227+
{
228+
return m_object->GetRuntimeClassName(result);
229+
}
230+
231+
HRESULT __stdcall GetTrustLevel(::TrustLevel* result) noexcept final
232+
{
233+
return m_object->GetTrustLevel(result);
234+
}
235+
236+
private:
237+
238+
com_ptr<::IInspectable> m_object;
239+
hstring m_value;
240+
std::atomic<uint32_t> m_references{ 1 };
241+
};
242+
243+
*result = new (std::nothrow) stringable(object, value);
244+
return *result ? S_OK : E_OUTOFMEMORY;
245+
}
246+
}
247+
248+
TEST_CASE("Composable tearoff")
249+
{
250+
static std::array<winrt::guid, 1> tearoff_iids{ winrt::guid_of<IStringable>() };
251+
252+
struct Tearoff : DerivedT<Tearoff, IClosable>
253+
{
254+
void Close()
255+
{
256+
}
257+
258+
int32_t query_interface_tearoff(winrt::guid const& id, void** result) const noexcept final
259+
{
260+
if (is_guid_of<IStringable>(id))
261+
{
262+
return make_stringable(*this, L"ToString", result);
263+
}
264+
265+
*result = nullptr;
266+
return E_NOINTERFACE;
267+
}
268+
269+
winrt::array_view<winrt::guid> get_iids_tearoff() const noexcept final
270+
{
271+
return winrt::array_view(tearoff_iids);
272+
}
273+
};
274+
275+
auto object = make<Tearoff>();
276+
auto ifaces = get_interfaces(object);
277+
278+
REQUIRE(object.as<IClosable>());
279+
REQUIRE(object.as<IStringable>());
280+
// IBaseOverrides is repeated twice for some reason, so actual size is 7 but there are only 6 unique interfaces
281+
REQUIRE(ifaces.size() >= 6);
282+
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<Base>()) != ifaces.end());
283+
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IBaseProtected>()) != ifaces.end());
284+
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IBaseOverrides>()) != ifaces.end());
285+
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<Derived>()) != ifaces.end());
286+
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IClosable>()) != ifaces.end());
287+
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IStringable>()) != ifaces.end());
288+
}

test/test/tearoff.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,13 @@ namespace
164164
*result = nullptr;
165165
return E_NOINTERFACE;
166166
}
167+
168+
inline static std::array<winrt::guid, 1> tearoff_iids = {winrt::guid_of<IPersist>()};
169+
170+
winrt::array_view<winrt::guid> get_iids_tearoff() const noexcept final
171+
{
172+
return winrt::array_view(tearoff_iids);
173+
}
167174
};
168175

169176
struct RuntimeType : winrt::implements<RuntimeType, winrt::IClosable>
@@ -195,6 +202,13 @@ namespace
195202

196203
*result = nullptr;
197204
return E_NOINTERFACE;
205+
}
206+
207+
inline static std::array<winrt::guid, 1> tearoff_iids{winrt::guid_of<winrt::IStringable>()};
208+
209+
winrt::array_view<winrt::guid> get_iids_tearoff() const noexcept final
210+
{
211+
return winrt::array_view(tearoff_iids);
198212
}
199213
};
200214
}
@@ -215,6 +229,11 @@ TEST_CASE("tearoff")
215229
REQUIRE(S_OK == persist->GetClassID(&result));
216230
REQUIRE(winrt::is_guid_of<IPersist>(result));
217231

232+
winrt::com_array<winrt::guid> iids = winrt::get_interfaces(closable);
233+
REQUIRE(iids.size() == 2);
234+
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IClosable>()) != iids.end());
235+
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<IPersist>()) != iids.end());
236+
218237
// query_interface_tearoff happily ignores any other queries.
219238
REQUIRE(closable.try_as<winrt::IActivationFactory>() == nullptr);
220239

0 commit comments

Comments
 (0)