Skip to content

Commit d1c34d0

Browse files
committed
Linux implementation of hat::memory_protector
1 parent 82f4466 commit d1c34d0

File tree

4 files changed

+109
-2
lines changed

4 files changed

+109
-2
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ set(LIBHAT_SRC
4747
src/Scanner.cpp
4848
src/System.cpp
4949

50+
src/os/linux/MemoryProtector.cpp
51+
5052
src/os/unix/System.cpp
5153

5254
src/os/win32/MemoryProtector.cpp

src/Utils.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <bit>
4+
#include <cstdint>
5+
6+
namespace hat::detail {
7+
8+
constexpr uintptr_t fast_align_down(uintptr_t address, size_t alignment) {
9+
return address & ~static_cast<uintptr_t>(alignment - 1);
10+
}
11+
12+
constexpr uintptr_t fast_align_up(uintptr_t address, size_t alignment) {
13+
return (address + alignment - 1) & ~static_cast<uintptr_t>(alignment - 1);
14+
}
15+
}

src/os/linux/MemoryProtector.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include <libhat/Defines.hpp>
2+
#ifdef LIBHAT_LINUX
3+
4+
#include <charconv>
5+
#include <fstream>
6+
#include <optional>
7+
#include <string>
8+
9+
#include <libhat/MemoryProtector.hpp>
10+
#include <libhat/System.hpp>
11+
#include "../../Utils.hpp"
12+
13+
#include <sys/mman.h>
14+
15+
namespace hat {
16+
17+
static int to_system_prot(const protection flags) {
18+
int prot = 0;
19+
if (static_cast<bool>(flags & protection::Read)) prot |= PROT_READ;
20+
if (static_cast<bool>(flags & protection::Write)) prot |= PROT_WRITE;
21+
if (static_cast<bool>(flags & protection::Execute)) prot |= PROT_EXEC;
22+
return prot;
23+
}
24+
25+
static std::optional<int> get_page_prot(const uintptr_t address) {
26+
std::ifstream f("/proc/self/maps");
27+
std::string s;
28+
while (std::getline(f, s)) {
29+
const char* it = s.data();
30+
const char* end = s.data() + s.size();
31+
std::from_chars_result res{};
32+
33+
uintptr_t pageBegin;
34+
res = std::from_chars(it, end, pageBegin, 16);
35+
if (res.ec != std::errc{} || res.ptr == end) {
36+
continue;
37+
}
38+
it = res.ptr + 1; // +1 to skip the hyphen
39+
40+
uintptr_t pageEnd;
41+
res = std::from_chars(it, end, pageEnd, 16);
42+
if (res.ec != std::errc{} || res.ptr == end) {
43+
continue;
44+
}
45+
it = res.ptr + 1; // +1 to skip the space
46+
47+
std::string_view remaining{it, end};
48+
if (address >= pageBegin && address < pageEnd && remaining.size() >= 3) {
49+
int prot = 0;
50+
if (remaining[0] == 'r') prot |= PROT_READ;
51+
if (remaining[1] == 'w') prot |= PROT_WRITE;
52+
if (remaining[2] == 'x') prot |= PROT_EXEC;
53+
return prot;
54+
}
55+
}
56+
return std::nullopt;
57+
}
58+
59+
memory_protector::memory_protector(const uintptr_t address, const size_t size, const protection flags) : address(address), size(size) {
60+
const auto pageSize = hat::get_system().page_size;
61+
62+
const auto oldProt = get_page_prot(address);
63+
if (!oldProt) {
64+
return; // Failure indicated via is_set()
65+
}
66+
67+
this->oldProtection = static_cast<uint32_t>(*oldProt);
68+
this->set = 0 == mprotect(
69+
reinterpret_cast<void*>(detail::fast_align_down(address, pageSize)),
70+
static_cast<size_t>(detail::fast_align_up(size, pageSize)),
71+
to_system_prot(flags)
72+
);
73+
}
74+
75+
void memory_protector::restore() {
76+
const auto pageSize = hat::get_system().page_size;
77+
mprotect(
78+
reinterpret_cast<void*>(detail::fast_align_down(address, pageSize)),
79+
static_cast<size_t>(detail::fast_align_up(size, pageSize)),
80+
this->oldProtection
81+
);
82+
}
83+
}
84+
#endif

src/os/win32/MemoryProtector.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
#include <Windows.h>
77

88
namespace hat {
9-
static DWORD ToWinProt(const protection flags) {
9+
10+
static DWORD to_system_prot(const protection flags) {
1011
const bool r = static_cast<bool>(flags & protection::Read);
1112
const bool w = static_cast<bool>(flags & protection::Write);
1213
const bool x = static_cast<bool>(flags & protection::Execute);
@@ -20,7 +21,12 @@ namespace hat {
2021
}
2122

2223
memory_protector::memory_protector(const uintptr_t address, const size_t size, const protection flags) : address(address), size(size) {
23-
this->set = 0 != VirtualProtect(reinterpret_cast<LPVOID>(this->address), this->size, ToWinProt(flags), reinterpret_cast<PDWORD>(&this->oldProtection));
24+
this->set = 0 != VirtualProtect(
25+
reinterpret_cast<LPVOID>(this->address),
26+
this->size,
27+
to_system_prot(flags),
28+
reinterpret_cast<PDWORD>(&this->oldProtection)
29+
);
2430
}
2531

2632
void memory_protector::restore() {

0 commit comments

Comments
 (0)