Skip to content

Commit fbc47ff

Browse files
authored
fix: Add type safety to FindMutable calls (#5886)
Fixes: #5316
1 parent cc0c7e5 commit fbc47ff

File tree

3 files changed

+29
-27
lines changed

3 files changed

+29
-27
lines changed

src/server/hset_family.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,9 +1164,11 @@ void HSetFamily::HRandField(CmdArgList args, const CommandContext& cmd_cntx) {
11641164
}
11651165

11661166
if (string_map->Empty()) { // Can happen if we use a TTL on hash members.
1167-
// post_updater will run immediately
1168-
auto it = db_slice.FindMutable(db_context, key).it;
1169-
db_slice.Del(db_context, it);
1167+
auto res_it = db_slice.FindMutable(db_context, key, OBJ_HASH);
1168+
if (res_it) {
1169+
res_it->post_updater.Run();
1170+
db_slice.Del(db_context, res_it->it);
1171+
}
11701172
return facade::OpStatus::KEY_NOTFOUND;
11711173
}
11721174
} else if (pv.Encoding() == kEncodingListPack) {

src/server/set_family.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,10 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const NewE
481481
// to overwrite the key. However, if the set is empty it means we should delete the
482482
// key if it exists.
483483
if (overwrite && (vals_it.begin() == vals_it.end())) {
484-
auto it = db_slice.FindMutable(op_args.db_cntx, key).it; // post_updater will run immediately
485-
if (IsValid(it)) {
486-
db_slice.Del(op_args.db_cntx, it);
484+
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
485+
if (res_it) {
486+
res_it->post_updater.Run();
487+
db_slice.Del(op_args.db_cntx, res_it->it);
487488
if (journal_update && op_args.shard->journal()) {
488489
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
489490
}

src/server/string_family.cc

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,12 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
285285
auto& db_slice = op_args.GetDbSlice();
286286

287287
// we avoid using AddOrFind because of skip_on_missing option for memcache.
288-
auto res = db_slice.FindMutable(op_args.db_cntx, key);
288+
auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);
289+
290+
if (!res) {
291+
if (res.status() == OpStatus::WRONG_TYPE)
292+
return res.status();
289293

290-
if (!IsValid(res.it)) {
291294
if (skip_on_missing)
292295
return OpStatus::KEY_NOTFOUND;
293296

@@ -300,11 +303,8 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
300303
return incr;
301304
}
302305

303-
if (res.it->second.ObjType() != OBJ_STRING) {
304-
return OpStatus::WRONG_TYPE;
305-
}
306-
307-
auto opt_prev = res.it->second.TryGetInt();
306+
// Type is already checked by FindMutable (OBJ_STRING)
307+
auto opt_prev = res->it->second.TryGetInt();
308308
if (!opt_prev) {
309309
return OpStatus::INVALID_VALUE;
310310
}
@@ -316,8 +316,8 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
316316
}
317317

318318
int64_t new_val = prev + incr;
319-
DCHECK(!res.it->second.IsExternal());
320-
res.it->second.SetInt(new_val);
319+
DCHECK(!res->it->second.IsExternal());
320+
res->it->second.SetInt(new_val);
321321

322322
return new_val;
323323
}
@@ -383,20 +383,19 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
383383
// Cost of this request
384384
const int64_t increment_ns = emission_interval_ns * quantity; // should be nonnegative
385385

386-
auto res = db_slice.FindMutable(op_args.db_cntx, key);
386+
auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);
387387
const int64_t now_ns = GetCurrentTimeNs();
388388

389389
int64_t tat_ns = now_ns;
390-
if (IsValid(res.it)) {
391-
if (res.it->second.ObjType() != OBJ_STRING) {
392-
return OpStatus::WRONG_TYPE;
393-
}
394-
395-
auto opt_prev = res.it->second.TryGetInt();
390+
if (res) {
391+
// Type is already checked by FindMutable (OBJ_STRING)
392+
auto opt_prev = res->it->second.TryGetInt();
396393
if (!opt_prev) {
397394
return OpStatus::INVALID_VALUE;
398395
}
399396
tat_ns = *opt_prev;
397+
} else if (res.status() == OpStatus::WRONG_TYPE) {
398+
return res.status();
400399
}
401400

402401
int64_t new_tat_ns = max(tat_ns, now_ns);
@@ -458,14 +457,14 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
458457
// break behavior because the tat_ns value will be used to check for throttling.
459458
const int64_t new_tat_ms =
460459
(new_tat_ns + kMilliSecondToNanoSecond - 1) / kMilliSecondToNanoSecond;
461-
if (IsValid(res.it)) {
462-
if (IsValid(res.exp_it)) {
463-
res.exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
460+
if (res) {
461+
if (IsValid(res->exp_it)) {
462+
res->exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
464463
} else {
465-
db_slice.AddExpire(op_args.db_cntx.db_index, res.it, new_tat_ms);
464+
db_slice.AddExpire(op_args.db_cntx.db_index, res->it, new_tat_ms);
466465
}
467466

468-
res.it->second.SetInt(new_tat_ns);
467+
res->it->second.SetInt(new_tat_ns);
469468
} else {
470469
CompactObj cobj;
471470
cobj.SetInt(new_tat_ns);

0 commit comments

Comments
 (0)