Skip to content

Commit 913c6cd

Browse files
committed
fix: Add type safety to FindMutable calls
1 parent accebe4 commit 913c6cd

File tree

3 files changed

+33
-27
lines changed

3 files changed

+33
-27
lines changed

src/server/hset_family.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,9 +1164,12 @@ void HSetFamily::HRandField(CmdArgList args, const CommandContext& cmd_cntx) {
11641164
}
11651165

11661166
if (string_map->Empty()) { // Can happen if we use a TTL on hash members.
1167-
// post_updater will run immediately
1168-
auto it = db_slice.FindMutable(db_context, key).it;
1169-
db_slice.Del(db_context, it);
1167+
// Use type-safe deletion (fixes #5316)
1168+
auto res_it = db_slice.FindMutable(db_context, key, OBJ_HASH);
1169+
if (res_it) {
1170+
res_it->post_updater.Run();
1171+
db_slice.Del(db_context, res_it->it);
1172+
}
11701173
return facade::OpStatus::KEY_NOTFOUND;
11711174
}
11721175
} else if (pv.Encoding() == kEncodingListPack) {

src/server/set_family.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,11 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const NewE
481481
// to overwrite the key. However, if the set is empty it means we should delete the
482482
// key if it exists.
483483
if (overwrite && (vals_it.begin() == vals_it.end())) {
484-
auto it = db_slice.FindMutable(op_args.db_cntx, key).it; // post_updater will run immediately
485-
if (IsValid(it)) {
486-
db_slice.Del(op_args.db_cntx, it);
484+
// Use type-safe deletion with OBJ_SET (fixes #5316)
485+
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
486+
if (res_it) {
487+
res_it->post_updater.Run();
488+
db_slice.Del(op_args.db_cntx, res_it->it);
487489
if (journal_update && op_args.shard->journal()) {
488490
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
489491
}

src/server/string_family.cc

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,13 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
285285
auto& db_slice = op_args.GetDbSlice();
286286

287287
// we avoid using AddOrFind because of skip_on_missing option for memcache.
288-
auto res = db_slice.FindMutable(op_args.db_cntx, key);
288+
// Use type-safe FindMutable with OBJ_STRING (fixes #5316)
289+
auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);
290+
291+
if (!res) {
292+
if (res.status() == OpStatus::WRONG_TYPE)
293+
return res.status();
289294

290-
if (!IsValid(res.it)) {
291295
if (skip_on_missing)
292296
return OpStatus::KEY_NOTFOUND;
293297

@@ -300,11 +304,8 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
300304
return incr;
301305
}
302306

303-
if (res.it->second.ObjType() != OBJ_STRING) {
304-
return OpStatus::WRONG_TYPE;
305-
}
306-
307-
auto opt_prev = res.it->second.TryGetInt();
307+
// Type is already checked by FindMutable (OBJ_STRING)
308+
auto opt_prev = res->it->second.TryGetInt();
308309
if (!opt_prev) {
309310
return OpStatus::INVALID_VALUE;
310311
}
@@ -316,8 +317,8 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, string_view key, int64_t incr,
316317
}
317318

318319
int64_t new_val = prev + incr;
319-
DCHECK(!res.it->second.IsExternal());
320-
res.it->second.SetInt(new_val);
320+
DCHECK(!res->it->second.IsExternal());
321+
res->it->second.SetInt(new_val);
321322

322323
return new_val;
323324
}
@@ -383,20 +384,20 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
383384
// Cost of this request
384385
const int64_t increment_ns = emission_interval_ns * quantity; // should be nonnegative
385386

386-
auto res = db_slice.FindMutable(op_args.db_cntx, key);
387+
// Use type-safe FindMutable with OBJ_STRING (fixes #5316)
388+
auto res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);
387389
const int64_t now_ns = GetCurrentTimeNs();
388390

389391
int64_t tat_ns = now_ns;
390-
if (IsValid(res.it)) {
391-
if (res.it->second.ObjType() != OBJ_STRING) {
392-
return OpStatus::WRONG_TYPE;
393-
}
394-
395-
auto opt_prev = res.it->second.TryGetInt();
392+
if (res) {
393+
// Type is already checked by FindMutable (OBJ_STRING)
394+
auto opt_prev = res->it->second.TryGetInt();
396395
if (!opt_prev) {
397396
return OpStatus::INVALID_VALUE;
398397
}
399398
tat_ns = *opt_prev;
399+
} else if (res.status() == OpStatus::WRONG_TYPE) {
400+
return res.status();
400401
}
401402

402403
int64_t new_tat_ns = max(tat_ns, now_ns);
@@ -458,14 +459,14 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
458459
// break behavior because the tat_ns value will be used to check for throttling.
459460
const int64_t new_tat_ms =
460461
(new_tat_ns + kMilliSecondToNanoSecond - 1) / kMilliSecondToNanoSecond;
461-
if (IsValid(res.it)) {
462-
if (IsValid(res.exp_it)) {
463-
res.exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
462+
if (res) {
463+
if (IsValid(res->exp_it)) {
464+
res->exp_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
464465
} else {
465-
db_slice.AddExpire(op_args.db_cntx.db_index, res.it, new_tat_ms);
466+
db_slice.AddExpire(op_args.db_cntx.db_index, res->it, new_tat_ms);
466467
}
467468

468-
res.it->second.SetInt(new_tat_ns);
469+
res->it->second.SetInt(new_tat_ns);
469470
} else {
470471
CompactObj cobj;
471472
cobj.SetInt(new_tat_ns);

0 commit comments

Comments
 (0)