|
| 1 | +#include <libhat/Defines.hpp> |
| 2 | +#ifdef LIBHAT_LINUX |
| 3 | + |
| 4 | +#include <charconv> |
| 5 | +#include <fstream> |
| 6 | +#include <optional> |
| 7 | +#include <string> |
| 8 | + |
| 9 | +#include <libhat/MemoryProtector.hpp> |
| 10 | +#include <libhat/System.hpp> |
| 11 | +#include "../../Utils.hpp" |
| 12 | + |
| 13 | +#include <sys/mman.h> |
| 14 | + |
| 15 | +namespace hat { |
| 16 | + |
| 17 | + static int to_system_prot(const protection flags) { |
| 18 | + int prot = 0; |
| 19 | + if (static_cast<bool>(flags & protection::Read)) prot |= PROT_READ; |
| 20 | + if (static_cast<bool>(flags & protection::Write)) prot |= PROT_WRITE; |
| 21 | + if (static_cast<bool>(flags & protection::Execute)) prot |= PROT_EXEC; |
| 22 | + return prot; |
| 23 | + } |
| 24 | + |
| 25 | + static std::optional<int> get_page_prot(const uintptr_t address) { |
| 26 | + std::ifstream f("/proc/self/maps"); |
| 27 | + std::string s; |
| 28 | + while (std::getline(f, s)) { |
| 29 | + const char* it = s.data(); |
| 30 | + const char* end = s.data() + s.size(); |
| 31 | + std::from_chars_result res{}; |
| 32 | + |
| 33 | + uintptr_t pageBegin; |
| 34 | + res = std::from_chars(it, end, pageBegin, 16); |
| 35 | + if (res.ec != std::errc{} || res.ptr == end) { |
| 36 | + continue; |
| 37 | + } |
| 38 | + it = res.ptr + 1; // +1 to skip the hyphen |
| 39 | + |
| 40 | + uintptr_t pageEnd; |
| 41 | + res = std::from_chars(it, end, pageEnd, 16); |
| 42 | + if (res.ec != std::errc{} || res.ptr == end) { |
| 43 | + continue; |
| 44 | + } |
| 45 | + it = res.ptr + 1; // +1 to skip the space |
| 46 | + |
| 47 | + std::string_view remaining{it, end}; |
| 48 | + if (address >= pageBegin && address < pageEnd && remaining.size() >= 3) { |
| 49 | + int prot = 0; |
| 50 | + if (remaining[0] == 'r') prot |= PROT_READ; |
| 51 | + if (remaining[1] == 'w') prot |= PROT_WRITE; |
| 52 | + if (remaining[2] == 'x') prot |= PROT_EXEC; |
| 53 | + return prot; |
| 54 | + } |
| 55 | + } |
| 56 | + return std::nullopt; |
| 57 | + } |
| 58 | + |
| 59 | + memory_protector::memory_protector(const uintptr_t address, const size_t size, const protection flags) : address(address), size(size) { |
| 60 | + const auto pageSize = hat::get_system().page_size; |
| 61 | + |
| 62 | + const auto oldProt = get_page_prot(address); |
| 63 | + if (!oldProt) { |
| 64 | + return; // Failure indicated via is_set() |
| 65 | + } |
| 66 | + |
| 67 | + this->oldProtection = static_cast<uint32_t>(*oldProt); |
| 68 | + this->set = 0 == mprotect( |
| 69 | + reinterpret_cast<void*>(detail::fast_align_down(address, pageSize)), |
| 70 | + static_cast<size_t>(detail::fast_align_up(size, pageSize)), |
| 71 | + to_system_prot(flags) |
| 72 | + ); |
| 73 | + } |
| 74 | + |
| 75 | + void memory_protector::restore() { |
| 76 | + const auto pageSize = hat::get_system().page_size; |
| 77 | + mprotect( |
| 78 | + reinterpret_cast<void*>(detail::fast_align_down(address, pageSize)), |
| 79 | + static_cast<size_t>(detail::fast_align_up(size, pageSize)), |
| 80 | + this->oldProtection |
| 81 | + ); |
| 82 | + } |
| 83 | +} |
| 84 | +#endif |
0 commit comments