Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions external_parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ set(binary_parser_headers
${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h
${CMAKE_CURRENT_LIST_DIR}/log_converter.h
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h
Expand All @@ -146,7 +146,7 @@ set(binary_parser_sources
${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/log_converter.cc
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/example_joiner.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "../rlclientlib/lru_dedup_cache.h"
#include "event_processors/joined_event.h"
#include "event_processors/loop.h"
#include "joiners/i_joiner.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/i_joiner.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#pragma once

#include "../rlclientlib/lru_dedup_cache.h"
#include "event_processors/reward.h"
#include "generated/v2/CbEvent_generated.h"
#include "generated/v2/FileFormat_generated.h"
#include "generated/v2/Metadata_generated.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/unit_tests/test_lru_dedup_cache.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <boost/test/unit_test.hpp>

#include "lru_dedup_cache.h"
#include "../rlclientlib/lru_dedup_cache.h"
#include "parse_example_external.h"
#include "test_common.h"
#include "vw/config/options_cli.h"
Expand Down
10 changes: 10 additions & 0 deletions include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ class live_model
*/
int init(api_status* status = nullptr);

/**
* @brief Load dedup cache.
* Load the dedup cache from the specified file. This cache is used to
* prevent duplicate actions from being sent to the online trainer.
* @param hash Hash of the dedup cache
* @param action_str Action string
* @return int Return error code. This will also be returned in the api_status object
*/
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status);

/**
* @brief Choose an action, given a list of actions, action features and context features. The
* inference library chooses an action by creating a probability distribution over the actions
Expand Down
1 change: 1 addition & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class i_model
{
public:
virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0;
virtual int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) = 0;
virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) = 0;
virtual int choose_continuous_action(string_view features, float& action, float& pdf_value,
Expand Down
2 changes: 2 additions & 0 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ set(PROJECT_SOURCES
logger/logger_facade.cc
logger/preamble.cc
logger/preamble_sender.cc
lru_dedup_cache.cc
model_mgmt/data_callback_fn.cc
model_mgmt/empty_data_transport.cc
model_mgmt/file_model_loader.cc
Expand Down Expand Up @@ -149,6 +150,7 @@ set(PROJECT_PRIVATE_HEADERS
logger/async_batcher.h
logger/event_logger.h
logger/logger_facade.h
lru_dedup_cache.h
model_mgmt/data_callback_fn.h
model_mgmt/empty_data_transport.h
model_mgmt/file_model_loader.h
Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/extensions/onnx/src/onnx_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ int onnx_model::update(const model_management::model_data& data, bool& model_rea
return error_code::success;
}

// TODO: Implement LRU cache for ONNX models.
int onnx_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return error_code::not_supported;
}

int onnx_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/extensions/onnx/src/onnx_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class onnx_model : public model_management::i_model
public:
onnx_model(i_trace* trace_logger, const char* app_id, const char* output_name, bool use_unstructured_input);
int update(const model_management::model_data& data, bool& model_ready, api_status* status = nullptr) override;
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;

Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ std::vector<int> live_model::c_array_to_vector(const int* c_array, size_t array_
return std::vector<int>(c_array, c_array + array_size);
}

int live_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
INIT_CHECK();
return _pimpl->add_lru_dedup_cache(hash, std::move(action_str), status);
}

int live_model::choose_rank(
const char* event_id, string_view context_json, ranking_response& response, api_status* status)
{
Expand Down
5 changes: 5 additions & 0 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ int live_model_impl::init(api_status* status)
return error_code::success;
}

int live_model_impl::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return _model->add_lru_dedup_cache(hash, std::move(action_str), status);
}

int live_model_impl::choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class live_model_impl

int init(api_status* status);

int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status);
int choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status);
// here the event_id is auto-generated
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct lru_dedup_cache
void* context = nullptr);
bool exists(uint64_t dedup_id);
void clear(release_example_f release_example = lru_dedup_cache::noop_release_example_f, void* context = nullptr);
std::unordered_map<uint64_t, VW::example*>* get_dict() { return &dedup_examples; }

lru_dedup_cache() = default;
~lru_dedup_cache() = default;
Expand Down
8 changes: 7 additions & 1 deletion rlclientlib/vw_model/pdf_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace model_management
// We construct a VW object here to use the example parser to parse joined dsjson-style examples
// to extract the PDF.
pdf_model::pdf_model(i_trace* trace_logger, const utility::configuration& /*unused*/)
: _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf"))
: _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf", nullptr))
{
}

Expand All @@ -23,6 +23,12 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta
return error_code::success;
}

// TODO: Implement LRU cache for PDF models.
int pdf_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return error_code::not_supported;
}

int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/vw_model/pdf_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class pdf_model : public i_model
public:
pdf_model(i_trace* trace_logger, const utility::configuration& config);
int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override;
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;
int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version,
Expand Down
68 changes: 52 additions & 16 deletions rlclientlib/vw_model/safe_vw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ namespace reinforcement_learning
{
static const std::string SEED_TAG = "seed=";

safe_vw::safe_vw(std::shared_ptr<safe_vw> master) : _master(std::move(master))
safe_vw::safe_vw(std::shared_ptr<safe_vw> master, lru_dedup_cache* dedup_cache)
: _master(std::move(master)), _dedup_cache(dedup_cache)
{
_vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr);
init();
}

safe_vw::safe_vw(const char* model_data, size_t len)
safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache)
{
io_buf buf;
buf.add_file(VW::io::create_buffer_view(model_data, len));
Expand All @@ -34,7 +35,8 @@ safe_vw::safe_vw(const char* model_data, size_t len)
init();
}

safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline)
safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache)
: _dedup_cache(dedup_cache)
{
io_buf buf;
buf.add_file(VW::io::create_buffer_view(model_data, len));
Expand All @@ -43,7 +45,7 @@ safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_comma
init();
}

safe_vw::safe_vw(const std::string& vw_commandline)
safe_vw::safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache)
{
_vw = VW::initialize(vw_commandline);
init();
Expand Down Expand Up @@ -120,6 +122,24 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector<int>& acti
for (auto&& ex : examples) { _example_pool.emplace_back(ex); }
}

void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str)
{
if (_dedup_cache == nullptr) { _dedup_cache = new lru_dedup_cache(); }
VW::multi_ex examples;
examples.push_back(get_or_create_example());

if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
else
{
VW::read_line_json_s<false>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
_dedup_cache->add(hash, examples[0]);
}

void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<float>& scores)
{
VW::multi_ex examples;
Expand All @@ -131,9 +151,14 @@ void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<f
if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this);
VW::read_line_json_s<true>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict());
}
else
{
VW::read_line_json_s<false>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict());
}
else { VW::read_line_json_s<false>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); }

// finalize example
VW::setup_examples(*_vw, examples);
Expand Down Expand Up @@ -372,19 +397,30 @@ void safe_vw::init()
}
}

safe_vw_factory::safe_vw_factory(std::string command_line) : _command_line(std::move(command_line)) {}
safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache)
: _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) : _master_data(master_data) {}
safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) : _master_data(master_data) {}
safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line)
: _master_data(master_data), _command_line(std::move(command_line))
safe_vw_factory::safe_vw_factory(
const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line)
: _master_data(master_data), _command_line(std::move(command_line))
safe_vw_factory::safe_vw_factory(
const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

Expand All @@ -393,13 +429,13 @@ safe_vw* safe_vw_factory::operator()()
if ((_master_data.data() != nullptr) && !_command_line.empty())
{
// Construct new vw object from raw model data and command line argument
return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line);
return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line, _dedup_cache);
}
if (_master_data.data() != nullptr)
{
// Construct new vw object from raw model data.
return new safe_vw(_master_data.data(), _master_data.data_sz());
return new safe_vw(_master_data.data(), _master_data.data_sz(), _dedup_cache);
}
return new safe_vw(_command_line);
return new safe_vw(_command_line, _dedup_cache);
}
} // namespace reinforcement_learning
24 changes: 15 additions & 9 deletions rlclientlib/vw_model/safe_vw.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "lru_dedup_cache.h"
#include "model_mgmt.h"
#include "vw/core/vw.h"

Expand All @@ -14,19 +15,21 @@ class safe_vw
std::shared_ptr<safe_vw> _master;
VW::workspace* _vw;
std::vector<VW::example*> _example_pool;
lru_dedup_cache* _dedup_cache;

VW::example* get_or_create_example();
static VW::example& get_or_create_example_f(void* vw);

public:
safe_vw(std::shared_ptr<safe_vw> master);
safe_vw(const char* model_data, size_t len, const std::string& vw_commandline);
safe_vw(const char* model_data, size_t len);
safe_vw(const std::string& vw_commandline);
safe_vw(std::shared_ptr<safe_vw> master, lru_dedup_cache* dedup_cache);
safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache);
safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache);
safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache);

~safe_vw();

void parse_context_with_pdf(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void add_lru_dedup_cache(uint64_t hash, std::string action_str);
void rank(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void choose_continuous_action(string_view context, float& action, float& pdf_value);
// Used for CCB
Expand Down Expand Up @@ -57,14 +60,17 @@ class safe_vw_factory
{
model_management::model_data _master_data;
std::string _command_line;
lru_dedup_cache* _dedup_cache;

public:
// model_data is copied and stored in the factory object.
safe_vw_factory(std::string command_line);
safe_vw_factory(const model_management::model_data& master_data);
safe_vw_factory(const model_management::model_data&& master_data);
safe_vw_factory(const model_management::model_data& master_data, std::string command_line);
safe_vw_factory(const model_management::model_data&& master_data, std::string command_line);
safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache);
safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache);
safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache);
safe_vw_factory(
const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache);
safe_vw_factory(
const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache);

safe_vw* operator()();
};
Expand Down
Loading