|
1 | 1 | #include "message_fs_repository.h" |
| 2 | +#include <algorithm> |
2 | 3 | #include <fstream> |
3 | 4 | #include <mutex> |
4 | 5 | #include "utils/result.hpp" |
@@ -52,7 +53,61 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, |
52 | 53 | auto mutex = GrabMutex(thread_id); |
53 | 54 | std::shared_lock<std::shared_mutex> lock(*mutex); |
54 | 55 |
|
55 | | - return ReadMessageFromFile(thread_id); |
| 56 | + auto read_result = ReadMessageFromFile(thread_id); |
| 57 | + if (read_result.has_error()) { |
| 58 | + return cpp::fail(read_result.error()); |
| 59 | + } |
| 60 | + |
| 61 | + std::vector<OpenAi::Message> messages = std::move(read_result.value()); |
| 62 | + |
| 63 | + if (!run_id.empty()) { |
| 64 | + messages.erase(std::remove_if(messages.begin(), messages.end(), |
| 65 | + [&run_id](const OpenAi::Message& msg) { |
| 66 | + return msg.run_id != run_id; |
| 67 | + }), |
| 68 | + messages.end()); |
| 69 | + } |
| 70 | + |
| 71 | + std::sort(messages.begin(), messages.end(), |
| 72 | + [&order](const OpenAi::Message& a, const OpenAi::Message& b) { |
| 73 | + if (order == "desc") { |
| 74 | + return a.created_at > b.created_at; |
| 75 | + } |
| 76 | + return a.created_at < b.created_at; |
| 77 | + }); |
| 78 | + |
| 79 | + auto start_it = messages.begin(); |
| 80 | + auto end_it = messages.end(); |
| 81 | + |
| 82 | + if (!after.empty()) { |
| 83 | + start_it = std::find_if( |
| 84 | + messages.begin(), messages.end(), |
| 85 | + [&after](const OpenAi::Message& msg) { return msg.id == after; }); |
| 86 | + if (start_it != messages.end()) { |
| 87 | + ++start_it; // Start from the message after the 'after' message |
| 88 | + } else { |
| 89 | + start_it = messages.begin(); |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + if (!before.empty()) { |
| 94 | + end_it = std::find_if( |
| 95 | + messages.begin(), messages.end(), |
| 96 | + [&before](const OpenAi::Message& msg) { return msg.id == before; }); |
| 97 | + } |
| 98 | + |
| 99 | + std::vector<OpenAi::Message> result; |
| 100 | + size_t distance = std::distance(start_it, end_it); |
| 101 | + size_t limit_size = static_cast<size_t>(limit); |
| 102 | + CTL_INF("Distance: " + std::to_string(distance) + |
| 103 | + ", limit_size: " + std::to_string(limit_size)); |
| 104 | + result.reserve(distance < limit_size ? distance : limit_size); |
| 105 | + |
| 106 | + for (auto it = start_it; it != end_it && result.size() < limit_size; ++it) { |
| 107 | + result.push_back(std::move(*it)); |
| 108 | + } |
| 109 | + |
| 110 | + return result; |
56 | 111 | } |
57 | 112 |
|
58 | 113 | cpp::result<OpenAi::Message, std::string> MessageFsRepository::RetrieveMessage( |
|
0 commit comments