diff --git a/include/kf/FltCommunicationPort.h b/include/kf/FltCommunicationPort.h index b7fbca2..457c7c0 100644 --- a/include/kf/FltCommunicationPort.h +++ b/include/kf/FltCommunicationPort.h @@ -200,7 +200,7 @@ namespace kf PMDL inputMdl = nullptr; PMDL outputMdl = nullptr; - NTSTATUS status = [&]() + NTSTATUS status = [&]() -> NTSTATUS { // // Lock user buffers so __try/__except is required only here and not in the handler->onMessage diff --git a/include/kf/ScopeFailure.h b/include/kf/ScopeFailure.h index 1d8cbf0..2bdc7c7 100644 --- a/include/kf/ScopeFailure.h +++ b/include/kf/ScopeFailure.h @@ -7,11 +7,11 @@ namespace kf { - template + template requires std::convertible_to class ScopeFailure { public: - ScopeFailure(NTSTATUS& status, F&& f) : m_status(status), m_f(f) + ScopeFailure(T& status, F&& f) : m_status(status), m_f(f) { } @@ -30,22 +30,23 @@ namespace kf ScopeFailure& operator=(const ScopeFailure&); private: - NTSTATUS& m_status; - F m_f; + T& m_status; + F m_f; }; + template requires std::convertible_to struct MakeScopeFailure { - MakeScopeFailure(NTSTATUS& status) : m_status(status) + MakeScopeFailure(T& status) : m_status(status) { } template - ScopeFailure operator+=(F&& f) + ScopeFailure operator+=(F&& f) { - return ScopeFailure(m_status, std::move(f)); + return ScopeFailure(m_status, std::move(f)); } - NTSTATUS& m_status; + T& m_status; }; } diff --git a/test/ScopeFailureTest.cpp b/test/ScopeFailureTest.cpp index 6ee26ce..fe07291 100644 --- a/test/ScopeFailureTest.cpp +++ b/test/ScopeFailureTest.cpp @@ -85,7 +85,7 @@ SCENARIO("SCOPE_FAILURE macro") SCOPE_FAILURE(status) { value++; - scopedStatus = status;; + scopedStatus = status; }; status = STATUS_CANT_WAIT; @@ -97,4 +97,54 @@ SCENARIO("SCOPE_FAILURE macro") REQUIRE(scopedStatus == STATUS_CANT_WAIT); } } + + GIVEN("SCOPE_FAILURE macro with a type convertible to NTSTATUS (unsuccessful case)") + { + struct StatusWrapper + { + NTSTATUS status; + operator NTSTATUS() const { return status; } + }; + + StatusWrapper status{ STATUS_SUCCESS }; + int value = 0; + + { + status.status = STATUS_ACCESS_DENIED; + + SCOPE_FAILURE(status) + { + value++; + }; + } + + THEN("Scoped function should be called for unsuccessful value") + { + REQUIRE(value == 1); + } + } + + GIVEN("SCOPE_FAILURE macro with a type convertible to NTSTATUS (successful case)") + { + struct StatusWrapper + { + NTSTATUS status; + operator NTSTATUS() const { return status; } + }; + + StatusWrapper status{ STATUS_SUCCESS }; + int value = 0; + + { + SCOPE_FAILURE(status) + { + value++; + }; + } + + THEN("Scoped function should not be called when value is successful") + { + REQUIRE(value == 0); + } + } }