diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 59c46c32afd4..a59a1bd6d2a1 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1164,9 +1164,11 @@ void HSetFamily::HRandField(CmdArgList args, const CommandContext& cmd_cntx) { } if (string_map->Empty()) { // Can happen if we use a TTL on hash members. - // post_updater will run immediately - auto it = db_slice.FindMutable(db_context, key).it; - db_slice.Del(db_context, it); + auto res_it = db_slice.FindMutable(db_context, key, OBJ_HASH); + if (res_it) { + res_it->post_updater.Run(); + db_slice.Del(db_context, res_it->it); + } return facade::OpStatus::KEY_NOTFOUND; } } else if (pv.Encoding() == kEncodingListPack) { diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 605a20c6d992..feee9dae7689 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -481,9 +481,10 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const NewE // to overwrite the key. However, if the set is empty it means we should delete the // key if it exists. if (overwrite && (vals_it.begin() == vals_it.end())) { - auto it = db_slice.FindMutable(op_args.db_cntx, key).it; // post_updater will run immediately - if (IsValid(it)) { - db_slice.Del(op_args.db_cntx, it); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET); + if (res_it) { + res_it->post_updater.Run(); + db_slice.Del(op_args.db_cntx, res_it->it); if (journal_update && op_args.shard->journal()) { RecordJournal(op_args, "DEL"sv, ArgSlice{key}); } diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 6037e1af5978..197ac312643e 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -285,9 +285,12 @@ OpResult OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr, auto& db_slice = op_args.GetDbSlice(); // we avoid using AddOrFind because of skip_on_missing option for memcache. - auto res = db_slice.FindMutable(op_args.db_cntx, key); + auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING); + + if (!res) { + if (res.status() == OpStatus::WRONG_TYPE) + return res.status(); - if (!IsValid(res.it)) { if (skip_on_missing) return OpStatus::KEY_NOTFOUND; @@ -300,11 +303,8 @@ OpResult OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr, return incr; } - if (res.it->second.ObjType() != OBJ_STRING) { - return OpStatus::WRONG_TYPE; - } - - auto opt_prev = res.it->second.TryGetInt(); + // Type is already checked by FindMutable (OBJ_STRING) + auto opt_prev = res->it->second.TryGetInt(); if (!opt_prev) { return OpStatus::INVALID_VALUE; } @@ -316,8 +316,8 @@ OpResult OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr, } int64_t new_val = prev + incr; - DCHECK(!res.it->second.IsExternal()); - res.it->second.SetInt(new_val); + DCHECK(!res->it->second.IsExternal()); + res->it->second.SetInt(new_val); return new_val; } @@ -383,20 +383,19 @@ OpResult> OpThrottle(const OpArgs& op_args, const string_view // Cost of this request const int64_t increment_ns = emission_interval_ns * quantity; // should be nonnegative - auto res = db_slice.FindMutable(op_args.db_cntx, key); + auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING); const int64_t now_ns = GetCurrentTimeNs(); int64_t tat_ns = now_ns; - if (IsValid(res.it)) { - if (res.it->second.ObjType() != OBJ_STRING) { - return OpStatus::WRONG_TYPE; - } - - auto opt_prev = res.it->second.TryGetInt(); + if (res) { + // Type is already checked by FindMutable (OBJ_STRING) + auto opt_prev = res->it->second.TryGetInt(); if (!opt_prev) { return OpStatus::INVALID_VALUE; } tat_ns = *opt_prev; + } else if (res.status() == OpStatus::WRONG_TYPE) { + return res.status(); } int64_t new_tat_ns = max(tat_ns, now_ns); @@ -458,14 +457,14 @@ OpResult> OpThrottle(const OpArgs& op_args, const string_view // break behavior because the tat_ns value will be used to check for throttling. const int64_t new_tat_ms = (new_tat_ns + kMilliSecondToNanoSecond - 1) / kMilliSecondToNanoSecond; - if (IsValid(res.it)) { - if (IsValid(res.exp_it)) { - res.exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms); + if (res) { + if (IsValid(res->exp_it)) { + res->exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms); } else { - db_slice.AddExpire(op_args.db_cntx.db_index, res.it, new_tat_ms); + db_slice.AddExpire(op_args.db_cntx.db_index, res->it, new_tat_ms); } - res.it->second.SetInt(new_tat_ns); + res->it->second.SetInt(new_tat_ns); } else { CompactObj cobj; cobj.SetInt(new_tat_ns);