Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/server/hset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1164,9 +1164,11 @@ void HSetFamily::HRandField(CmdArgList args, const CommandContext& cmd_cntx) {
}

if (string_map->Empty()) { // Can happen if we use a TTL on hash members.
// post_updater will run immediately
auto it = db_slice.FindMutable(db_context, key).it;
db_slice.Del(db_context, it);
auto res_it = db_slice.FindMutable(db_context, key, OBJ_HASH);
if (res_it) {
res_it->post_updater.Run();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need to run it manually?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AutoUpdater dtor automatically calls Run() on destruction, which has a DCHECK that verifies the key still exists in the database.
Since we call db_slice.Del() earlier in the code (which removes the key), we must manually call post_updater.Run() before deletion to update memory accounting while the key is still valid. Otherwise, when res_it goes out of scope, the destructor will call Run() on an already-deleted key and the DCHECK will fail.
This pattern is used throughout the codebase whenever a key is deleted after FindMutable().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can unify somehow this code, I mean do delete in post_updater or call run during delete.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern (post_updater.Run() + Del()) appears many times across several files in the codebase, so unifying it makes sense.

Suggested approach:
Add a new method DelMutable() to DbSlice:

void DbSlice::DelMutable(Context cntx, ItAndUpdater& it_updater) {
  it_updater.post_updater.Run();
  Del(cntx, it_updater.it);
}

This would simplify code like:

// From
res_it->post_updater.Run();
db_slice.Del(op_args.db_cntx, res_it->it);

// To
db_slice.DelMutable(op_args.db_cntx, *res_it);

However, I'd suggest doing this as a separate follow-up PR after this one, because:

  • This PR focuses on type safety fixes
  • Refactoring many call sites is a larger change that deserves separate review
  • Mixing two different goals increases risk and review complexity

Maybe better create a separate issue for this refactoring?

db_slice.Del(db_context, res_it->it);
}
return facade::OpStatus::KEY_NOTFOUND;
}
} else if (pv.Encoding() == kEncodingListPack) {
Expand Down
7 changes: 4 additions & 3 deletions src/server/set_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const NewE
// to overwrite the key. However, if the set is empty it means we should delete the
// key if it exists.
if (overwrite && (vals_it.begin() == vals_it.end())) {
auto it = db_slice.FindMutable(op_args.db_cntx, key).it; // post_updater will run immediately
if (IsValid(it)) {
db_slice.Del(op_args.db_cntx, it);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
if (res_it) {
res_it->post_updater.Run();
db_slice.Del(op_args.db_cntx, res_it->it);
if (journal_update && op_args.shard->journal()) {
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
}
Expand Down
41 changes: 20 additions & 21 deletions src/server/string_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,12 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
auto& db_slice = op_args.GetDbSlice();

// we avoid using AddOrFind because of skip_on_missing option for memcache.
auto res = db_slice.FindMutable(op_args.db_cntx, key);
auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);

if (!res) {
if (res.status() == OpStatus::WRONG_TYPE)
return res.status();

if (!IsValid(res.it)) {
if (skip_on_missing)
return OpStatus::KEY_NOTFOUND;

Expand All @@ -300,11 +303,8 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
return incr;
}

if (res.it->second.ObjType() != OBJ_STRING) {
return OpStatus::WRONG_TYPE;
}

auto opt_prev = res.it->second.TryGetInt();
// Type is already checked by FindMutable (OBJ_STRING)
auto opt_prev = res->it->second.TryGetInt();
if (!opt_prev) {
return OpStatus::INVALID_VALUE;
}
Expand All @@ -316,8 +316,8 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
}

int64_t new_val = prev + incr;
DCHECK(!res.it->second.IsExternal());
res.it->second.SetInt(new_val);
DCHECK(!res->it->second.IsExternal());
res->it->second.SetInt(new_val);

return new_val;
}
Expand Down Expand Up @@ -383,20 +383,19 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
// Cost of this request
const int64_t increment_ns = emission_interval_ns * quantity; // should be nonnegative

auto res = db_slice.FindMutable(op_args.db_cntx, key);
auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);
const int64_t now_ns = GetCurrentTimeNs();

int64_t tat_ns = now_ns;
if (IsValid(res.it)) {
if (res.it->second.ObjType() != OBJ_STRING) {
return OpStatus::WRONG_TYPE;
}

auto opt_prev = res.it->second.TryGetInt();
if (res) {
// Type is already checked by FindMutable (OBJ_STRING)
auto opt_prev = res->it->second.TryGetInt();
if (!opt_prev) {
return OpStatus::INVALID_VALUE;
}
tat_ns = *opt_prev;
} else if (res.status() == OpStatus::WRONG_TYPE) {
return res.status();
}

int64_t new_tat_ns = max(tat_ns, now_ns);
Expand Down Expand Up @@ -458,14 +457,14 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
// break behavior because the tat_ns value will be used to check for throttling.
const int64_t new_tat_ms =
(new_tat_ns + kMilliSecondToNanoSecond - 1) / kMilliSecondToNanoSecond;
if (IsValid(res.it)) {
if (IsValid(res.exp_it)) {
res.exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
if (res) {
if (IsValid(res->exp_it)) {
res->exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
} else {
db_slice.AddExpire(op_args.db_cntx.db_index, res.it, new_tat_ms);
db_slice.AddExpire(op_args.db_cntx.db_index, res->it, new_tat_ms);
}

res.it->second.SetInt(new_tat_ns);
res->it->second.SetInt(new_tat_ns);
} else {
CompactObj cobj;
cobj.SetInt(new_tat_ns);
Expand Down
Loading