diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..56b6c9b --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,35 @@ +# Repository Guidelines + +## Project Structure & Modules +- `cmd/server`, `cmd/client`: entrypoints for running the KV server and client. +- `store/`: MVCC storage engine, OCC/TTL, and related tests. +- `kv/`: hybrid logical clock (HLC) utilities and KV interfaces. +- `adapter/`: protocol adapters (e.g., Redis), plus integration tests. +- `jepsen/`, `jepsen/redis/`: Jepsen test harnesses and workloads. +- `proto/`, `distribution/`, `internal/`: supporting protobufs, build assets, and shared helpers. + +## Build, Test, and Development Commands +- `go test ./...` — run unit/integration tests. If macOS sandbox blocks `$GOCACHE`, prefer `GOCACHE=$(pwd)/.cache GOTMPDIR=$(pwd)/.cache/tmp go test ./...`. +- `GOCACHE=$(pwd)/.cache GOLANGCI_LINT_CACHE=$(pwd)/.golangci-cache golangci-lint run ./... --timeout=5m` — full lint suite. +- `HOME=$(pwd)/jepsen/tmp-home LEIN_HOME=$(pwd)/jepsen/.lein LEIN_JVM_OPTS="-Duser.home=$(pwd)/jepsen/tmp-home" /tmp/lein test` (from `jepsen/` or `jepsen/redis/`) — Jepsen tests. +- `go run ./cmd/server` / `go run ./cmd/client` — start server or CLI locally. + +## Coding Style & Naming +- Go code: `gofmt` + project lint rules (`golangci-lint`). Avoid adding `//nolint` unless absolutely required; prefer refactoring. +- Naming: Go conventions (MixedCaps for exported identifiers, short receiver names). Filenames remain lowercase with underscores only where existing. +- Logging: use `slog` where present; maintain structured keys (`key`, `commit_ts`, etc.). + +## Testing Guidelines +- Unit tests co-located with packages (`*_test.go`); prefer table-driven cases. +- TTL/HLC behaviors live in `store/` and `kv/`; add coverage when touching clocks, OCC, or replication logic. +- Integration: run Jepsen suites after changes affecting replication, MVCC, or Redis adapter. + - `cd jepsen && HOME=$(pwd)/tmp-home LEIN_HOME=$(pwd)/.lein LEIN_JVM_OPTS="-Duser.home=$(pwd)/tmp-home" /tmp/lein test` + - `cd jepsen/redis && HOME=$(pwd)/../tmp-home LEIN_HOME=$(pwd)/../.lein LEIN_JVM_OPTS="-Duser.home=$(pwd)/../tmp-home" /tmp/lein test` + +## Commit & Pull Request Guidelines +- Messages: short imperative summary (e.g., “Add HLC TTL handling”). Include scope when helpful (`store:`, `adapter:`). +- Pull requests: describe behavior change, risk, and test evidence (`go test`, lint, Jepsen). Add repro steps for bug fixes. + +## Security & Configuration Tips +- Hybrid clock derives from wall-clock millis; keep system clock reasonably synchronized across nodes. +- Avoid leader-local timestamps in persistence; timestamp issuance should originate from the Raft leader to prevent skewed OCC decisions. diff --git a/adapter/internal.go b/adapter/internal.go index 04d4aac..ff6fdde 100644 --- a/adapter/internal.go +++ b/adapter/internal.go @@ -9,16 +9,18 @@ import ( "github.com/hashicorp/raft" ) -func NewInternal(txm kv.Transactional, r *raft.Raft) *Internal { +func NewInternal(txm kv.Transactional, r *raft.Raft, clock *kv.HLC) *Internal { return &Internal{ raft: r, transactionManager: txm, + clock: clock, } } type Internal struct { raft *raft.Raft transactionManager kv.Transactional + clock *kv.HLC pb.UnimplementedInternalServer } @@ -33,6 +35,19 @@ func (i *Internal) Forward(_ context.Context, req *pb.ForwardRequest) (*pb.Forwa return nil, errors.WithStack(ErrNotLeader) } + // Ensure leader issues start_ts when followers forward txn groups without it. 
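+	// A single start_ts is shared by every request in the group so the whole
+	// transaction commits against one snapshot; requests that already carry a
+	// timestamp from the coordinator are left untouched.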
+ if req.IsTxn { + var startTs uint64 + for _, r := range req.Requests { + if r.Ts == 0 { + if startTs == 0 { + startTs = i.clock.Next() + } + r.Ts = startTs + } + } + } + r, err := i.transactionManager.Commit(req.Requests) if err != nil { return &pb.ForwardResponse{ diff --git a/adapter/test_util.go b/adapter/test_util.go index 5d023a4..9d616e4 100644 --- a/adapter/test_util.go +++ b/adapter/test_util.go @@ -322,9 +322,8 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( cfg := buildRaftConfig(n, ports) for i := 0; i < n; i++ { - st := store.NewRbMemoryStore() - trxSt := store.NewMemoryStoreDefaultTTL() - fsm := kv.NewKvFSM(st, trxSt) + st := store.NewMVCCStore() + fsm := kv.NewKvFSM(st) port := ports[i] grpcSock := lis[i].grpc @@ -351,7 +350,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( tm.Register(s) pb.RegisterRawKVServer(s, gs) pb.RegisterTransactionalKVServer(s, gs) - pb.RegisterInternalServer(s, NewInternal(trx, r)) + pb.RegisterInternalServer(s, NewInternal(trx, r, coordinator.Clock())) leaderhealth.Setup(r, s, []string{"Example"}) raftadmin.Register(s, r) diff --git a/cmd/server/demo.go b/cmd/server/demo.go index bc768ea..2e0c0c6 100644 --- a/cmd/server/demo.go +++ b/cmd/server/demo.go @@ -81,9 +81,8 @@ func run(eg *errgroup.Group) error { } for i := 0; i < 3; i++ { - st := store.NewRbMemoryStore() - trxSt := store.NewMemoryStoreDefaultTTL() - fsm := kv.NewKvFSM(st, trxSt) + st := store.NewMVCCStore() + fsm := kv.NewKvFSM(st) r, tm, err := newRaft(strconv.Itoa(i), grpcAdders[i], fsm, i == 0, cfg) if err != nil { @@ -97,7 +96,7 @@ func run(eg *errgroup.Group) error { tm.Register(s) pb.RegisterRawKVServer(s, gs) pb.RegisterTransactionalKVServer(s, gs) - pb.RegisterInternalServer(s, adapter.NewInternal(trx, r)) + pb.RegisterInternalServer(s, adapter.NewInternal(trx, r, coordinator.Clock())) leaderhealth.Setup(r, s, []string{"RawKV"}) raftadmin.Register(s, r) diff --git a/go.mod b/go.mod index 53a98c5..33b2640 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/tidwall/redcon v1.6.2 go.etcd.io/bbolt v1.4.3 golang.org/x/sync v0.19.0 + golang.org/x/sys v0.38.0 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 ) @@ -68,7 +69,6 @@ require ( github.com/tidwall/btree v1.1.0 // indirect github.com/tidwall/match v1.1.1 // indirect golang.org/x/net v0.47.0 // indirect - golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/kv/coordinator.go b/kv/coordinator.go index 0031295..a85d52a 100644 --- a/kv/coordinator.go +++ b/kv/coordinator.go @@ -14,6 +14,7 @@ func NewCoordinator(txm Transactional, r *raft.Raft) *Coordinate { return &Coordinate{ transactionManager: txm, raft: r, + clock: NewHLC(), } } @@ -24,6 +25,7 @@ type CoordinateResponse struct { type Coordinate struct { transactionManager Transactional raft *raft.Raft + clock *HLC } var _ Coordinator = (*Coordinate)(nil) @@ -39,8 +41,13 @@ func (c *Coordinate) Dispatch(reqs *OperationGroup[OP]) (*CoordinateResponse, er return c.redirect(reqs) } + if reqs.IsTxn && reqs.StartTS == 0 { + // Leader-only timestamp issuance to avoid divergence across shards. 
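+		// Non-leaders have already returned via redirect above; a redirected
+		// group keeps StartTS at zero, so the leader's Internal.Forward handler
+		// mints the timestamp for it instead.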
+ reqs.StartTS = c.nextStartTS() + } + if reqs.IsTxn { - return c.dispatchTxn(reqs.Elems) + return c.dispatchTxn(reqs.Elems, reqs.StartTS) } return c.dispatchRaw(reqs.Elems) @@ -56,10 +63,18 @@ func (c *Coordinate) RaftLeader() raft.ServerAddress { return addr } -func (c *Coordinate) dispatchTxn(reqs []*Elem[OP]) (*CoordinateResponse, error) { +func (c *Coordinate) Clock() *HLC { + return c.clock +} + +func (c *Coordinate) nextStartTS() uint64 { + return c.clock.Next() +} + +func (c *Coordinate) dispatchTxn(reqs []*Elem[OP], startTS uint64) (*CoordinateResponse, error) { var logs []*pb.Request for _, req := range reqs { - m := c.toTxnRequests(req) + m := c.toTxnRequests(req, startTS) logs = append(logs, m...) } @@ -96,6 +111,7 @@ func (c *Coordinate) toRawRequest(req *Elem[OP]) *pb.Request { return &pb.Request{ IsTxn: false, Phase: pb.Phase_NONE, + Ts: c.clock.Next(), Mutations: []*pb.Mutation{ { Op: pb.Op_PUT, @@ -109,6 +125,7 @@ func (c *Coordinate) toRawRequest(req *Elem[OP]) *pb.Request { return &pb.Request{ IsTxn: false, Phase: pb.Phase_NONE, + Ts: c.clock.Next(), Mutations: []*pb.Mutation{ { Op: pb.Op_DEL, @@ -121,16 +138,14 @@ func (c *Coordinate) toRawRequest(req *Elem[OP]) *pb.Request { panic("unreachable") } -const defaultTxnLockTTLSeconds = uint64(30) - -func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { +func (c *Coordinate) toTxnRequests(req *Elem[OP], startTS uint64) []*pb.Request { switch req.Op { case Put: return []*pb.Request{ { IsTxn: true, Phase: pb.Phase_PREPARE, - Ts: defaultTxnLockTTLSeconds, + Ts: startTS, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -141,7 +156,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_COMMIT, - Ts: defaultTxnLockTTLSeconds, + Ts: startTS, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -156,7 +171,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_PREPARE, - Ts: defaultTxnLockTTLSeconds, + Ts: startTS, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -166,7 +181,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_COMMIT, - Ts: defaultTxnLockTTLSeconds, + Ts: startTS, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -204,7 +219,7 @@ func (c *Coordinate) redirect(reqs *OperationGroup[OP]) (*CoordinateResponse, er var requests []*pb.Request if reqs.IsTxn { for _, req := range reqs.Elems { - requests = append(requests, c.toTxnRequests(req)...) + requests = append(requests, c.toTxnRequests(req, reqs.StartTS)...) 
} } else { for _, req := range reqs.Elems { diff --git a/kv/fsm.go b/kv/fsm.go index 63d34ff..aeb0ab5 100644 --- a/kv/fsm.go +++ b/kv/fsm.go @@ -2,13 +2,10 @@ package kv import ( "context" - "encoding/binary" "io" "log/slog" "os" - "time" - "github.com/bootjp/elastickv/internal" pb "github.com/bootjp/elastickv/proto" "github.com/bootjp/elastickv/store" "github.com/cockroachdb/errors" @@ -17,19 +14,17 @@ import ( ) type kvFSM struct { - store store.Store - lockStore store.TTLStore - log *slog.Logger + store store.MVCCStore + log *slog.Logger } type FSM interface { raft.FSM } -func NewKvFSM(store store.Store, lockStore store.TTLStore) FSM { +func NewKvFSM(store store.MVCCStore) FSM { return &kvFSM{ - store: store, - lockStore: lockStore, + store: store, log: slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelWarn, })), @@ -50,7 +45,7 @@ func (f *kvFSM) Apply(l *raft.Log) interface{} { return errors.WithStack(err) } - err = f.handleRequest(ctx, r) + err = f.handleRequest(ctx, r, r.Ts) if err != nil { return errors.WithStack(err) } @@ -58,34 +53,23 @@ func (f *kvFSM) Apply(l *raft.Log) interface{} { return nil } -func (f *kvFSM) handleRequest(ctx context.Context, r *pb.Request) error { +func (f *kvFSM) handleRequest(ctx context.Context, r *pb.Request, commitTS uint64) error { switch { case r.IsTxn: - return f.handleTxnRequest(ctx, r) + return f.handleTxnRequest(ctx, r, commitTS) default: - return f.handleRawRequest(ctx, r) + return f.handleRawRequest(ctx, r, commitTS) } } -func (f *kvFSM) handleRawRequest(ctx context.Context, r *pb.Request) error { - for _, mut := range r.Mutations { - switch mut.Op { - case pb.Op_PUT: - err := f.store.Put(ctx, mut.Key, mut.Value) - if err != nil { - return errors.WithStack(err) - } - case pb.Op_DEL: - err := f.store.Delete(ctx, mut.Key) - if err != nil { - return errors.WithStack(err) - } - default: - return errors.WithStack(ErrUnknownRequestType) - } +func (f *kvFSM) handleRawRequest(ctx context.Context, r *pb.Request, commitTS uint64) error { + muts, err := toStoreMutations(r.Mutations) + if err != nil { + return errors.WithStack(err) } - - return nil + // Raw requests always commit against the latest state; use commitTS as both + // the validation snapshot and the commit timestamp. 
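+	// With startTS == commitTS the conflict check in ApplyMutations only rejects
+	// the write when the store already holds a version newer than this request's
+	// timestamp, and alignCommitTS still bumps commitTS past lastCommitTS before
+	// the new versions are appended.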
+ return errors.WithStack(f.store.ApplyMutations(ctx, muts, commitTS, commitTS)) } var ErrNotImplemented = errors.New("not implemented") @@ -106,12 +90,12 @@ func (f *kvFSM) Restore(r io.ReadCloser) error { return errors.WithStack(f.store.Restore(r)) } -func (f *kvFSM) handleTxnRequest(ctx context.Context, r *pb.Request) error { +func (f *kvFSM) handleTxnRequest(ctx context.Context, r *pb.Request, commitTS uint64) error { switch r.Phase { case pb.Phase_PREPARE: return f.handlePrepareRequest(ctx, r) case pb.Phase_COMMIT: - return f.handleCommitRequest(ctx, r) + return f.handleCommitRequest(ctx, r, commitTS) case pb.Phase_ABORT: return f.handleAbortRequest(ctx, r) case pb.Phase_NONE: @@ -122,121 +106,54 @@ func (f *kvFSM) handleTxnRequest(ctx context.Context, r *pb.Request) error { } } -var ErrKeyAlreadyLocked = errors.New("key already locked") -var ErrKeyNotLocked = errors.New("key not locked") - -func (f *kvFSM) hasLock(txn store.Txn, key []byte) (bool, error) { - //nolint:wrapcheck - return internal.WithStacks(txn.Exists(context.Background(), key)) -} -func (f *kvFSM) lock(txn store.TTLTxn, key []byte, ttl uint64) error { - ittl, err := internal.Uint64ToInt64(ttl) - if err != nil { - return errors.WithStack(err) - } - //nolint:mnd - b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, ttl) - expire := time.Now().Unix() + ittl - return errors.WithStack(txn.PutWithTTL(context.Background(), key, b, expire)) -} - -func (f *kvFSM) unlock(txn store.Txn, key []byte) error { - return errors.WithStack(txn.Delete(context.Background(), key)) +func (f *kvFSM) validateConflicts(ctx context.Context, muts []*pb.Mutation, startTS uint64) error { + // Debug guard only: real OCC runs at the leader/storage layer, so conflicts + // should already be resolved before log application. Keep this stub to make + // any unexpected violations visible during development. 
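+	// A full check would mirror mvccStore.ApplyMutations: look up the latest
+	// commit timestamp for each mutated key and fail when it exceeds startTS
+	// (sketch only; not implemented here).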
+ return nil } func (f *kvFSM) handlePrepareRequest(ctx context.Context, r *pb.Request) error { - err := f.lockStore.TxnWithTTL(ctx, func(ctx context.Context, txn store.TTLTxn) error { - for _, mut := range r.Mutations { - if exist, _ := txn.Exists(ctx, mut.Key); exist { - return errors.WithStack(ErrKeyAlreadyLocked) - } - //nolint:mnd - err := f.lock(txn, mut.Key, r.Ts) - if err != nil { - return errors.WithStack(err) - } - } - return nil - }) + err := f.validateConflicts(ctx, r.Mutations, r.Ts) f.log.InfoContext(ctx, "handlePrepareRequest finish") - return errors.WithStack(err) } -func (f *kvFSM) commit(ctx context.Context, txn store.Txn, mut *pb.Mutation) error { - switch mut.Op { - case pb.Op_PUT: - err := txn.Put(ctx, mut.Key, mut.Value) - if err != nil { - return errors.WithStack(err) - } - case pb.Op_DEL: - err := txn.Delete(ctx, mut.Key) - if err != nil { - return errors.WithStack(err) - } +func (f *kvFSM) handleCommitRequest(ctx context.Context, r *pb.Request, commitTS uint64) error { + muts, err := toStoreMutations(r.Mutations) + if err != nil { + return errors.WithStack(err) + } + if err := f.validateConflicts(ctx, r.Mutations, r.Ts); err != nil { + return errors.WithStack(err) } - return nil + return errors.WithStack(f.store.ApplyMutations(ctx, muts, r.Ts, commitTS)) } -func (f *kvFSM) handleCommitRequest(ctx context.Context, r *pb.Request) error { - // Release locks regardless of success or failure - defer func() { - err := f.lockStore.Txn(ctx, func(ctx context.Context, txn store.Txn) error { - for _, mut := range r.Mutations { - err := f.unlock(txn, mut.Key) - if err != nil { - return errors.WithStack(err) - } - } - return nil - }) - - if err != nil { - f.log.ErrorContext(ctx, "unlock error", slog.String("err", err.Error())) - } - }() - - err := f.lockStore.Txn(ctx, func(ctx context.Context, lockTxn store.Txn) error { - err := f.store.Txn(ctx, func(ctx context.Context, txn store.Txn) error { - // commit - for _, mut := range r.Mutations { - ok, err := f.hasLock(lockTxn, mut.Key) - if err != nil { - return errors.WithStack(err) - } - - if !ok { - // Lock already gone: treat as conflict and abort. - return errors.WithStack(ErrKeyNotLocked) - } - - err = f.commit(ctx, txn, mut) - if err != nil { - return errors.WithStack(err) - } - } - return nil - }) - - return errors.WithStack(err) - }) - - return errors.WithStack(err) +func (f *kvFSM) handleAbortRequest(_ context.Context, _ *pb.Request) error { + // OCC does not rely on locks; abort is a no-op. 
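+	// Prepare no longer stages writes or takes locks, so an aborted transaction
+	// has left nothing in the store to roll back.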
+ return nil } -func (f *kvFSM) handleAbortRequest(ctx context.Context, r *pb.Request) error { - err := f.lockStore.Txn(ctx, func(ctx context.Context, txn store.Txn) error { - for _, mut := range r.Mutations { - err := txn.Delete(ctx, mut.Key) - if err != nil { - return errors.WithStack(err) - } +func toStoreMutations(muts []*pb.Mutation) ([]*store.KVPairMutation, error) { + out := make([]*store.KVPairMutation, 0, len(muts)) + for _, mut := range muts { + switch mut.Op { + case pb.Op_PUT: + out = append(out, &store.KVPairMutation{ + Op: store.OpTypePut, + Key: mut.Key, + Value: mut.Value, + }) + case pb.Op_DEL: + out = append(out, &store.KVPairMutation{ + Op: store.OpTypeDelete, + Key: mut.Key, + }) + default: + return nil, ErrUnknownRequestType } - return nil - }) - - return errors.WithStack(err) + } + return out, nil } diff --git a/kv/hlc.go b/kv/hlc.go new file mode 100644 index 0000000..53d07df --- /dev/null +++ b/kv/hlc.go @@ -0,0 +1,90 @@ +package kv + +import ( + "math" + "sync" + "time" +) + +const hlcLogicalBits = 16 +const hlcLogicalMask uint64 = (1 << hlcLogicalBits) - 1 + +// HLC implements a simple hybrid logical clock suitable for issuing +// monotonically increasing timestamps across shards/raft groups. +// +// Layout (ms logical): +// +// high 48 bits: wall clock milliseconds since Unix epoch +// low 16 bits : logical counter to break ties when wall time does not advance +// +// This keeps ordering stable across leaders as long as clocks are loosely +// synchronized; it avoids dependence on per-raft commit indices that diverge +// between shards. +type HLC struct { + mu sync.Mutex + lastWall int64 + logical uint16 +} + +func nonNegativeUint64(v int64) uint64 { + if v < 0 { + return 0 + } + return uint64(v) +} + +func clampUint64ToInt64(v uint64) int64 { + if v > math.MaxInt64 { + return math.MaxInt64 + } + return int64(v) +} + +func clampUint64ToUint16(v uint64) uint16 { + max := uint64(^uint16(0)) + if v > max { + return uint16(max) + } + return uint16(v) +} + +func NewHLC() *HLC { + return &HLC{} +} + +// Next returns the next hybrid logical timestamp. +func (h *HLC) Next() uint64 { + h.mu.Lock() + defer h.mu.Unlock() + + now := time.Now().UnixMilli() + if now > h.lastWall { + h.lastWall = now + h.logical = 0 + } else { + h.logical++ + if h.logical == 0 { // overflow; bump wall to keep monotonicity + h.lastWall++ + } + } + + wall := nonNegativeUint64(h.lastWall) + return (wall << hlcLogicalBits) | uint64(h.logical) +} + +// Observe bumps the local clock if a higher timestamp is seen. 
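+// After Observe, Next returns a value greater than any timestamp observed so
+// far. Illustrative usage (remoteTS stands for a timestamp minted elsewhere):
+//
+//	clock.Observe(remoteTS)
+//	local := clock.Next() // local > remoteTS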
+func (h *HLC) Observe(ts uint64) { + wallPart := ts >> hlcLogicalBits + logicalPart := ts & hlcLogicalMask + + h.mu.Lock() + defer h.mu.Unlock() + + wall := clampUint64ToInt64(wallPart) + logical := clampUint64ToUint16(logicalPart) + + if wall > h.lastWall || (wall == h.lastWall && logical > h.logical) { + h.lastWall = wall + h.logical = logical + } +} diff --git a/kv/shard_router_test.go b/kv/shard_router_test.go index ad65eb8..7016e4d 100644 --- a/kv/shard_router_test.go +++ b/kv/shard_router_test.go @@ -87,7 +87,7 @@ func initTestRafts(t *testing.T, cfg raft.Configuration, trans []*raft.InmemTran if i == 0 { rfsm = fsm } else { - rfsm = NewKvFSM(store.NewRbMemoryStore(), store.NewRbMemoryStoreWithExpire(time.Minute)) + rfsm = NewKvFSM(store.NewMVCCStore()) } r, err := raft.NewRaft(c, rfsm, ldb, sdb, fss, trans[i]) if err != nil { @@ -127,16 +127,14 @@ func TestShardRouterCommit(t *testing.T) { router := NewShardRouter(e) // group 1 - s1 := store.NewRbMemoryStore() - l1 := store.NewRbMemoryStoreWithExpire(time.Minute) - r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1)) + s1 := store.NewMVCCStore() + r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1)) defer stop1() router.Register(1, NewTransaction(r1), s1) // group 2 - s2 := store.NewRbMemoryStore() - l2 := store.NewRbMemoryStoreWithExpire(time.Minute) - r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2)) + s2 := store.NewMVCCStore() + r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2)) defer stop2() router.Register(2, NewTransaction(r2), s2) @@ -169,16 +167,14 @@ func TestShardRouterSplitAndMerge(t *testing.T) { router := NewShardRouter(e) // group 1 - s1 := store.NewRbMemoryStore() - l1 := store.NewRbMemoryStoreWithExpire(time.Minute) - r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1, l1)) + s1 := store.NewMVCCStore() + r1, stop1 := newTestRaft(t, "1", NewKvFSM(s1)) defer stop1() router.Register(1, NewTransaction(r1), s1) // group 2 (will be used after split) - s2 := store.NewRbMemoryStore() - l2 := store.NewRbMemoryStoreWithExpire(time.Minute) - r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2, l2)) + s2 := store.NewMVCCStore() + r2, stop2 := newTestRaft(t, "2", NewKvFSM(s2)) defer stop2() router.Register(2, NewTransaction(r2), s2) diff --git a/kv/snapshot_test.go b/kv/snapshot_test.go index 19e4003..89ab872 100644 --- a/kv/snapshot_test.go +++ b/kv/snapshot_test.go @@ -12,9 +12,8 @@ import ( ) func TestSnapshot(t *testing.T) { - store := store3.NewMemoryStore() - txnStore := store3.NewMemoryStoreDefaultTTL() - fsm := NewKvFSM(store, txnStore) + store := store3.NewMVCCStore() + fsm := NewKvFSM(store) mutation := pb.Request{ IsTxn: false, @@ -47,9 +46,8 @@ func TestSnapshot(t *testing.T) { snapshot, err := fsm.Snapshot() assert.NoError(t, err) - store2 := store3.NewMemoryStore() - trxnStore2 := store3.NewMemoryStoreDefaultTTL() - fsm2 := NewKvFSM(store2, trxnStore2) + store2 := store3.NewMVCCStore() + fsm2 := NewKvFSM(store2) kvFSMSnap, ok := snapshot.(*kvFSMSnapshot) assert.True(t, ok) diff --git a/kv/transcoder.go b/kv/transcoder.go index 6ea563c..837232a 100644 --- a/kv/transcoder.go +++ b/kv/transcoder.go @@ -20,4 +20,7 @@ type Elem[T OP] struct { type OperationGroup[T OP] struct { Elems []*Elem[T] IsTxn bool + // StartTS is a logical timestamp captured at transaction begin. + // It is ignored for non-transactional groups. 
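+	// A zero value means no timestamp has been issued yet; the Raft leader fills
+	// it in during Dispatch (or via Internal.Forward for redirected groups).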
+ StartTS uint64 } diff --git a/main.go b/main.go index b56b2ae..034576c 100644 --- a/main.go +++ b/main.go @@ -53,9 +53,8 @@ func main() { log.Fatalf("failed to listen: %v", err) } - s := store.NewRbMemoryStore() - lockStore := store.NewMemoryStoreDefaultTTL() - kvFSM := kv.NewKvFSM(s, lockStore) + s := store.NewMVCCStore() + kvFSM := kv.NewKvFSM(s) r, tm, err := NewRaft(ctx, *raftId, *myAddr, kvFSM) if err != nil { @@ -67,7 +66,7 @@ func main() { coordinate := kv.NewCoordinator(trx, r) pb.RegisterRawKVServer(gs, adapter.NewGRPCServer(s, coordinate)) pb.RegisterTransactionalKVServer(gs, adapter.NewGRPCServer(s, coordinate)) - pb.RegisterInternalServer(gs, adapter.NewInternal(trx, r)) + pb.RegisterInternalServer(gs, adapter.NewInternal(trx, r, coordinate.Clock())) tm.Register(gs) leaderhealth.Setup(r, gs, []string{"RawKV", "Example"}) diff --git a/store/mvcc_store.go b/store/mvcc_store.go new file mode 100644 index 0000000..4ff9cd5 --- /dev/null +++ b/store/mvcc_store.go @@ -0,0 +1,723 @@ +package store + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/gob" + "hash/crc32" + "io" + "log/slog" + "os" + "sync" + "time" + + "github.com/cockroachdb/errors" + "github.com/emirpasic/gods/maps/treemap" +) + +// VersionedValue represents a single committed version in MVCC storage. +type VersionedValue struct { + TS uint64 + Value []byte + Tombstone bool + ExpireAt uint64 // HLC timestamp; 0 means no TTL +} + +const ( + hlcLogicalBits = 16 + msPerSecond = 1000 +) + +func withinBoundsKey(k, start, end []byte) bool { + if start != nil && bytes.Compare(k, start) < 0 { + return false + } + if end != nil && bytes.Compare(k, end) > 0 { + return false + } + return true +} + +// mvccStore is an in-memory MVCC implementation backed by a treemap for +// deterministic iteration order and range scans. +type mvccStore struct { + tree *treemap.Map // key []byte -> []VersionedValue + mtx sync.RWMutex + log *slog.Logger + lastCommitTS uint64 + clock HybridClock +} + +// NewMVCCStore creates a new MVCC-enabled in-memory store. +func NewMVCCStore() MVCCStore { + return NewMVCCStoreWithClock(defaultHLC{}) +} + +// NewMVCCStoreWithClock allows injecting a hybrid clock (for tests or cluster-wide clocks). +func NewMVCCStoreWithClock(clock HybridClock) MVCCStore { + return &mvccStore{ + tree: treemap.NewWith(byteSliceComparator), + log: slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelWarn, + })), + clock: clock, + } +} + +type defaultHLC struct{} + +func nonNegativeMillis() uint64 { + nowMs := time.Now().UnixMilli() + if nowMs < 0 { + return 0 + } + return uint64(nowMs) +} + +func (defaultHLC) Now() uint64 { + return nonNegativeMillis() << hlcLogicalBits +} + +var _ MVCCStore = (*mvccStore)(nil) +var _ ScanStore = (*mvccStore)(nil) +var _ Store = (*mvccStore)(nil) + +// ---- helpers guarded by caller locks ---- + +func latestVisible(vs []VersionedValue, ts uint64) (VersionedValue, bool) { + for i := len(vs) - 1; i >= 0; i-- { + if vs[i].TS <= ts { + if vs[i].ExpireAt != 0 && vs[i].ExpireAt <= ts { + // Treat expired value as a tombstone for visibility. 
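+				// Expiry is evaluated against the read timestamp ts rather than
+				// the wall clock, so a reader at an older snapshot still sees the
+				// value until its ExpireAt is reached.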
+ return VersionedValue{}, false + } + return vs[i], true + } + } + return VersionedValue{}, false +} + +func visibleValue(versions []VersionedValue, ts uint64) ([]byte, bool) { + ver, ok := latestVisible(versions, ts) + if !ok || ver.Tombstone { + return nil, false + } + return ver.Value, true +} + +func visibleTxnValue(tv mvccTxnValue, now uint64) ([]byte, bool) { + if tv.tombstone { + return nil, false + } + if tv.expireAt != 0 && tv.expireAt <= now { + return nil, false + } + return tv.value, true +} + +func cloneKVPair(key, val []byte) *KVPair { + return &KVPair{ + Key: bytes.Clone(key), + Value: bytes.Clone(val), + } +} + +type iterEntry struct { + key []byte + ok bool + versions []VersionedValue + stageVal mvccTxnValue +} + +func nextBaseEntry(it *treemap.Iterator, start, end []byte) iterEntry { + for it.Next() { + k, ok := it.Key().([]byte) + if !ok { + continue + } + if !withinBoundsKey(k, start, end) { + if end != nil && bytes.Compare(k, end) > 0 { + return iterEntry{} + } + continue + } + versions, _ := it.Value().([]VersionedValue) + return iterEntry{key: k, ok: true, versions: versions} + } + return iterEntry{} +} + +func nextStageEntry(it *treemap.Iterator, start, end []byte) iterEntry { + for it.Next() { + k, ok := it.Key().([]byte) + if !ok { + continue + } + if !withinBoundsKey(k, start, end) { + if end != nil && bytes.Compare(k, end) > 0 { + return iterEntry{} + } + continue + } + val, _ := it.Value().(mvccTxnValue) + return iterEntry{key: k, ok: true, stageVal: val} + } + return iterEntry{} +} + +func (s *mvccStore) nextCommitTSLocked() uint64 { + return s.alignCommitTS(s.clock.Now()) +} + +func (s *mvccStore) putVersionLocked(key, value []byte, commitTS, expireAt uint64) { + existing, _ := s.tree.Get(key) + var versions []VersionedValue + if existing != nil { + versions, _ = existing.([]VersionedValue) + } + versions = append(versions, VersionedValue{ + TS: commitTS, + Value: bytes.Clone(value), + Tombstone: false, + ExpireAt: expireAt, + }) + s.tree.Put(bytes.Clone(key), versions) +} + +func (s *mvccStore) deleteVersionLocked(key []byte, commitTS uint64) { + existing, _ := s.tree.Get(key) + var versions []VersionedValue + if existing != nil { + versions, _ = existing.([]VersionedValue) + } + versions = append(versions, VersionedValue{ + TS: commitTS, + Value: nil, + Tombstone: true, + ExpireAt: 0, + }) + s.tree.Put(bytes.Clone(key), versions) +} + +func (s *mvccStore) ttlExpireAt(ttl int64) uint64 { + now := s.readTS() + if ttl <= 0 { + return now + } + // ttl is seconds; convert to milliseconds then shift to HLC layout. 
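+	// Example: ttl=1 adds 1000<<hlcLogicalBits, i.e. one second in the
+	// wall-clock portion of the timestamp while the logical counter bits stay
+	// untouched.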
+ deltaMs := uint64(ttl) * msPerSecond + return now + (deltaMs << hlcLogicalBits) +} + +func (s *mvccStore) readTS() uint64 { + now := s.clock.Now() + if now < s.lastCommitTS { + return s.lastCommitTS + } + return now +} + +func (s *mvccStore) alignCommitTS(commitTS uint64) uint64 { + ts := commitTS + read := s.readTS() + if ts < read { + ts = read + } + if ts <= s.lastCommitTS { + ts = s.lastCommitTS + 1 + } + s.lastCommitTS = ts + return ts +} + +func (s *mvccStore) latestVersionLocked(key []byte) (VersionedValue, bool) { + v, ok := s.tree.Get(key) + if !ok { + return VersionedValue{}, false + } + vs, _ := v.([]VersionedValue) + if len(vs) == 0 { + return VersionedValue{}, false + } + return vs[len(vs)-1], true +} + +// ---- MVCCStore methods ---- + +func (s *mvccStore) GetAt(ctx context.Context, key []byte, ts uint64) ([]byte, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + v, ok := s.tree.Get(key) + if !ok { + return nil, ErrKeyNotFound + } + versions, _ := v.([]VersionedValue) + ver, ok := latestVisible(versions, ts) + if !ok || ver.Tombstone { + return nil, ErrKeyNotFound + } + return bytes.Clone(ver.Value), nil +} + +func (s *mvccStore) LatestCommitTS(_ context.Context, key []byte) (uint64, bool, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + ver, ok := s.latestVersionLocked(key) + if !ok { + return 0, false, nil + } + return ver.TS, true, nil +} + +func (s *mvccStore) ApplyMutations(ctx context.Context, mutations []*KVPairMutation, startTS, commitTS uint64) error { + s.mtx.Lock() + defer s.mtx.Unlock() + + for _, mut := range mutations { + if latestVer, ok := s.latestVersionLocked(mut.Key); ok && latestVer.TS > startTS { + return errors.Wrapf(ErrWriteConflict, "key: %s", string(mut.Key)) + } + } + + commitTS = s.alignCommitTS(commitTS) + + for _, mut := range mutations { + switch mut.Op { + case OpTypePut: + s.putVersionLocked(mut.Key, mut.Value, commitTS, mut.ExpireAt) + case OpTypeDelete: + s.deleteVersionLocked(mut.Key, commitTS) + default: + return errors.WithStack(ErrUnknownOp) + } + s.log.InfoContext(ctx, "apply mutation", + slog.String("key", string(mut.Key)), + slog.Uint64("commit_ts", commitTS), + slog.Bool("delete", mut.Op == OpTypeDelete), + ) + } + + return nil +} + +// ---- Store / ScanStore methods ---- + +func (s *mvccStore) Get(_ context.Context, key []byte) ([]byte, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + now := s.readTS() + + v, ok := s.tree.Get(key) + if !ok { + return nil, ErrKeyNotFound + } + versions, _ := v.([]VersionedValue) + val, ok := visibleValue(versions, now) + if !ok { + return nil, ErrKeyNotFound + } + return bytes.Clone(val), nil +} + +func (s *mvccStore) Put(ctx context.Context, key []byte, value []byte) error { + s.mtx.Lock() + defer s.mtx.Unlock() + s.putVersionLocked(key, value, s.nextCommitTSLocked(), 0) + s.log.InfoContext(ctx, "put", + slog.String("key", string(key)), + slog.String("value", string(value)), + ) + return nil +} + +func (s *mvccStore) PutWithTTL(ctx context.Context, key []byte, value []byte, ttl int64) error { + s.mtx.Lock() + defer s.mtx.Unlock() + + exp := s.ttlExpireAt(ttl) + s.putVersionLocked(key, value, s.nextCommitTSLocked(), exp) + s.log.InfoContext(ctx, "put_ttl", + slog.String("key", string(key)), + slog.String("value", string(value)), + slog.Int64("ttl_sec", ttl), + ) + return nil +} + +func (s *mvccStore) Delete(ctx context.Context, key []byte) error { + s.mtx.Lock() + defer s.mtx.Unlock() + s.deleteVersionLocked(key, s.nextCommitTSLocked()) + s.log.InfoContext(ctx, "delete", + 
slog.String("key", string(key)), + ) + return nil +} + +func (s *mvccStore) Exists(_ context.Context, key []byte) (bool, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + now := s.readTS() + + v, ok := s.tree.Get(key) + if !ok { + return false, nil + } + versions, _ := v.([]VersionedValue) + if len(versions) == 0 { + return false, nil + } + ver, ok := latestVisible(versions, now) + if !ok { + return false, nil + } + return !ver.Tombstone, nil +} + +func (s *mvccStore) Expire(ctx context.Context, key []byte, ttl int64) error { + s.mtx.Lock() + defer s.mtx.Unlock() + + now := s.clock.Now() + v, ok := s.tree.Get(key) + if !ok { + return ErrKeyNotFound + } + versions, _ := v.([]VersionedValue) + if len(versions) == 0 { + return ErrKeyNotFound + } + ver := versions[len(versions)-1] + if ver.Tombstone || (ver.ExpireAt != 0 && ver.ExpireAt <= now) { + return ErrKeyNotFound + } + + exp := s.ttlExpireAt(ttl) + s.putVersionLocked(key, ver.Value, s.nextCommitTSLocked(), exp) + s.log.InfoContext(ctx, "expire", + slog.String("key", string(key)), + slog.Int64("ttl_sec", ttl), + ) + return nil +} + +func (s *mvccStore) Scan(_ context.Context, start []byte, end []byte, limit int) ([]*KVPair, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + if limit <= 0 { + return []*KVPair{}, nil + } + + now := s.readTS() + + capHint := limit + if size := s.tree.Size(); size < capHint { + capHint = size + } + if capHint < 0 { + capHint = 0 + } + + result := make([]*KVPair, 0, capHint) + s.tree.Each(func(key interface{}, value interface{}) { + if len(result) >= limit { + return + } + k, ok := key.([]byte) + if !ok || !withinBoundsKey(k, start, end) { + return + } + + versions, _ := value.([]VersionedValue) + val, ok := visibleValue(versions, now) + if !ok { + return + } + + result = append(result, &KVPair{ + Key: bytes.Clone(k), + Value: bytes.Clone(val), + }) + }) + + return result, nil +} + +func (s *mvccStore) Txn(ctx context.Context, f func(ctx context.Context, txn Txn) error) error { + s.mtx.Lock() + defer s.mtx.Unlock() + + txn := &mvccTxn{ + stage: treemap.NewWith(byteSliceComparator), + ops: []mvccOp{}, + s: s, + } + + if err := f(ctx, txn); err != nil { + return errors.WithStack(err) + } + + commitTS := s.nextCommitTSLocked() + for _, op := range txn.ops { + switch op.opType { + case OpTypePut: + s.putVersionLocked(op.key, op.value, commitTS, op.expireAt) + case OpTypeDelete: + s.deleteVersionLocked(op.key, commitTS) + default: + return errors.WithStack(ErrUnknownOp) + } + } + + return nil +} + +func (s *mvccStore) TxnWithTTL(ctx context.Context, f func(ctx context.Context, txn TTLTxn) error) error { + return s.Txn(ctx, func(ctx context.Context, txn Txn) error { + tt, ok := txn.(*mvccTxn) + if !ok { + return ErrNotSupported + } + return f(ctx, tt) + }) +} + +func (s *mvccStore) Snapshot() (io.ReadWriter, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + state := make([]mvccSnapshotEntry, 0, s.tree.Size()) + s.tree.Each(func(key interface{}, value interface{}) { + k, ok := key.([]byte) + if !ok { + return + } + versions, ok := value.([]VersionedValue) + if !ok { + return + } + state = append(state, mvccSnapshotEntry{ + Key: bytes.Clone(k), + Versions: append([]VersionedValue(nil), versions...), + }) + }) + + buf := &bytes.Buffer{} + if err := gob.NewEncoder(buf).Encode(state); err != nil { + return nil, errors.WithStack(err) + } + + sum := crc32.ChecksumIEEE(buf.Bytes()) + if err := binary.Write(buf, binary.LittleEndian, sum); err != nil { + return nil, errors.WithStack(err) + } + + return buf, nil +} + 
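+// Restore replaces the current state with a snapshot produced by Snapshot:
+// a gob-encoded []mvccSnapshotEntry followed by a little-endian CRC32 of the
+// payload. The checksum is verified before decoding, and lastCommitTS is
+// rebuilt from the newest version of each restored key.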
+func (s *mvccStore) Restore(r io.Reader) error { + data, err := io.ReadAll(r) + if err != nil { + return errors.WithStack(err) + } + if len(data) < checksumSize { + return errors.WithStack(ErrInvalidChecksum) + } + payload := data[:len(data)-checksumSize] + expected := binary.LittleEndian.Uint32(data[len(data)-checksumSize:]) + if crc32.ChecksumIEEE(payload) != expected { + return errors.WithStack(ErrInvalidChecksum) + } + + var state []mvccSnapshotEntry + if err := gob.NewDecoder(bytes.NewReader(payload)).Decode(&state); err != nil { + return errors.WithStack(err) + } + + s.mtx.Lock() + defer s.mtx.Unlock() + + s.tree.Clear() + for _, entry := range state { + versions := append([]VersionedValue(nil), entry.Versions...) + s.tree.Put(bytes.Clone(entry.Key), versions) + if len(versions) > 0 { + last := versions[len(versions)-1].TS + if last > s.lastCommitTS { + s.lastCommitTS = last + } + } + } + + return nil +} + +func (s *mvccStore) Close() error { + return nil +} + +// ---- transactional staging ---- + +type mvccOp struct { + opType OpType + key []byte + value []byte + expireAt uint64 +} + +type mvccTxn struct { + stage *treemap.Map // key []byte -> mvccTxnValue + ops []mvccOp + s *mvccStore +} + +type mvccTxnValue struct { + value []byte + tombstone bool + expireAt uint64 +} + +func (t *mvccTxn) Get(_ context.Context, key []byte) ([]byte, error) { + if v, ok := t.stage.Get(key); ok { + tv, _ := v.(mvccTxnValue) + if tv.tombstone { + return nil, ErrKeyNotFound + } + if tv.expireAt != 0 && tv.expireAt <= t.s.clock.Now() { + return nil, ErrKeyNotFound + } + return bytes.Clone(tv.value), nil + } + + return t.s.Get(context.Background(), key) +} + +func (t *mvccTxn) Put(_ context.Context, key []byte, value []byte) error { + t.stage.Put(key, mvccTxnValue{value: bytes.Clone(value)}) + t.ops = append(t.ops, mvccOp{opType: OpTypePut, key: bytes.Clone(key), value: bytes.Clone(value)}) + return nil +} + +func (t *mvccTxn) Delete(_ context.Context, key []byte) error { + t.stage.Put(key, mvccTxnValue{tombstone: true}) + t.ops = append(t.ops, mvccOp{opType: OpTypeDelete, key: bytes.Clone(key)}) + return nil +} + +func (t *mvccTxn) Exists(_ context.Context, key []byte) (bool, error) { + if v, ok := t.stage.Get(key); ok { + tv, _ := v.(mvccTxnValue) + if tv.expireAt != 0 && tv.expireAt <= t.s.clock.Now() { + return false, nil + } + return !tv.tombstone, nil + } + return t.s.Exists(context.Background(), key) +} + +func (t *mvccTxn) Expire(_ context.Context, key []byte, ttl int64) error { + exp := t.s.ttlExpireAt(ttl) + + if v, ok := t.stage.Get(key); ok { + tv, _ := v.(mvccTxnValue) + if tv.tombstone { + return ErrKeyNotFound + } + tv.expireAt = exp + t.stage.Put(key, tv) + t.ops = append(t.ops, mvccOp{opType: OpTypePut, key: bytes.Clone(key), value: bytes.Clone(tv.value), expireAt: exp}) + return nil + } + + val, err := t.s.Get(context.Background(), key) + if err != nil { + return err + } + t.stage.Put(key, mvccTxnValue{value: bytes.Clone(val), expireAt: exp}) + t.ops = append(t.ops, mvccOp{opType: OpTypePut, key: bytes.Clone(key), value: bytes.Clone(val), expireAt: exp}) + return nil +} + +func (t *mvccTxn) PutWithTTL(_ context.Context, key []byte, value []byte, ttl int64) error { + exp := t.s.ttlExpireAt(ttl) + t.stage.Put(key, mvccTxnValue{value: bytes.Clone(value), expireAt: exp}) + t.ops = append(t.ops, mvccOp{opType: OpTypePut, key: bytes.Clone(key), value: bytes.Clone(value), expireAt: exp}) + return nil +} + +func (t *mvccTxn) Scan(_ context.Context, start []byte, end []byte, limit int) 
([]*KVPair, error) { + if limit <= 0 { + return []*KVPair{}, nil + } + + totalSize := t.s.tree.Size() + t.stage.Size() + capHint := limit + if totalSize < capHint { + capHint = totalSize + } + if capHint < 0 { + capHint = 0 + } + + result := make([]*KVPair, 0, capHint) + now := t.s.clock.Now() + + baseIt := t.s.tree.Iterator() + baseIt.Begin() + stageIt := t.stage.Iterator() + stageIt.Begin() + + result = mergeTxnEntries(result, limit, start, end, now, &baseIt, &stageIt) + + return result, nil +} + +func mergeTxnEntries(result []*KVPair, limit int, start []byte, end []byte, now uint64, baseIt, stageIt *treemap.Iterator) []*KVPair { + baseNext := nextBaseEntry(baseIt, start, end) + stageNext := nextStageEntry(stageIt, start, end) + + for len(result) < limit && (baseNext.ok || stageNext.ok) { + useStage := chooseStage(baseNext, stageNext) + + if useStage { + k := stageNext.key + if val, visible := visibleTxnValue(stageNext.stageVal, now); visible { + result = append(result, cloneKVPair(k, val)) + } + if baseNext.ok && bytes.Equal(baseNext.key, k) { + baseNext = nextBaseEntry(baseIt, start, end) + } + stageNext = nextStageEntry(stageIt, start, end) + continue + } + + if val, ok := visibleValue(baseNext.versions, now); ok { + result = append(result, cloneKVPair(baseNext.key, val)) + } + baseNext = nextBaseEntry(baseIt, start, end) + } + + return result +} + +func chooseStage(baseNext, stageNext iterEntry) bool { + if !baseNext.ok { + return stageNext.ok + } + if !stageNext.ok { + return false + } + return bytes.Compare(stageNext.key, baseNext.key) <= 0 +} + +// mvccSnapshotEntry is used solely for gob snapshot serialization. +type mvccSnapshotEntry struct { + Key []byte + Versions []VersionedValue +} diff --git a/store/mvcc_store_test.go b/store/mvcc_store_test.go new file mode 100644 index 0000000..f3f0b58 --- /dev/null +++ b/store/mvcc_store_test.go @@ -0,0 +1,69 @@ +package store + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type mockClock struct { + ts uint64 +} + +func (m *mockClock) Now() uint64 { return m.ts } +func (m *mockClock) advanceMs(ms uint64) { + m.ts += ms << hlcLogicalBits // HLC encodes milliseconds in high bits +} + +func TestMVCCStore_PutWithTTL_Expires(t *testing.T) { + ctx := context.Background() + clock := &mockClock{ts: 0} + st := NewMVCCStoreWithClock(clock) + + require.NoError(t, st.PutWithTTL(ctx, []byte("k"), []byte("v"), 1)) + + v, err := st.Get(ctx, []byte("k")) + require.NoError(t, err) + require.Equal(t, []byte("v"), v) + + clock.advanceMs(1500) // 1.5s later + _, err = st.Get(ctx, []byte("k")) + require.ErrorIs(t, err, ErrKeyNotFound) +} + +func TestMVCCStore_ExpireExisting(t *testing.T) { + ctx := context.Background() + clock := &mockClock{ts: 0} + st := NewMVCCStoreWithClock(clock) + + require.NoError(t, st.Put(ctx, []byte("k"), []byte("v"))) + require.NoError(t, st.Expire(ctx, []byte("k"), 1)) + + clock.advanceMs(500) + v, err := st.Get(ctx, []byte("k")) + require.NoError(t, err) + require.Equal(t, []byte("v"), v) + + clock.advanceMs(600) // total 1.1s + _, err = st.Get(ctx, []byte("k")) + require.ErrorIs(t, err, ErrKeyNotFound) +} + +func TestMVCCStore_TxnWithTTL(t *testing.T) { + ctx := context.Background() + clock := &mockClock{ts: 0} + st := NewMVCCStoreWithClock(clock) + + require.NoError(t, st.TxnWithTTL(ctx, func(ctx context.Context, txn TTLTxn) error { + return txn.PutWithTTL(ctx, []byte("k"), []byte("v"), 1) + })) + + v, err := st.Get(ctx, []byte("k")) + require.NoError(t, err) + require.Equal(t, []byte("v"), 
v) + + clock.advanceMs(1100) + _, err = st.Get(ctx, []byte("k")) + require.ErrorIs(t, err, ErrKeyNotFound) +} diff --git a/store/store.go b/store/store.go index c9f818b..8d5be42 100644 --- a/store/store.go +++ b/store/store.go @@ -11,6 +11,8 @@ var ErrKeyNotFound = errors.New("not found") var ErrUnknownOp = errors.New("unknown op") var ErrNotSupported = errors.New("not supported") var ErrInvalidChecksum = errors.New("invalid checksum") +var ErrWriteConflict = errors.New("write conflict") +var ErrExpired = errors.New("expired") type KVPair struct { Key []byte @@ -35,6 +37,38 @@ type ScanStore interface { Scan(ctx context.Context, start []byte, end []byte, limit int) ([]*KVPair, error) } +// HybridClock provides monotonically increasing timestamps (HLC). +type HybridClock interface { + Now() uint64 +} + +// MVCCStore extends Store with multi-version concurrency control helpers. +// Reads can be evaluated at an arbitrary timestamp, and commits validate +// conflicts against the latest committed version. +type MVCCStore interface { + ScanStore + TTLStore + + // GetAt returns the newest version whose commit timestamp is <= ts. + GetAt(ctx context.Context, key []byte, ts uint64) ([]byte, error) + // LatestCommitTS returns the commit timestamp of the newest version. + // The boolean reports whether the key has any version. + LatestCommitTS(ctx context.Context, key []byte) (uint64, bool, error) + // ApplyMutations atomically validates and appends the provided mutations. + // It must return ErrWriteConflict if any key has a newer commit timestamp + // than startTS. + ApplyMutations(ctx context.Context, mutations []*KVPairMutation, startTS, commitTS uint64) error +} + +// KVPairMutation is a small helper struct for MVCC mutation application. +type KVPairMutation struct { + Op OpType + Key []byte + Value []byte + // ExpireAt is an HLC timestamp; 0 means no TTL. + ExpireAt uint64 +} + type TTLStore interface { Store Expire(ctx context.Context, key []byte, ttl int64) error