44package state
55
66import (
7- "errors"
87 "fmt"
98
109 "github.com/google/btree"
@@ -14,8 +13,6 @@ import (
1413 "github.com/ava-labs/avalanchego/utils/iterator"
1514)
1615
17- var ErrAddingStakerAfterDeletion = errors.New("attempted to add a staker after deleting it")
18-
1916type Stakers interface {
2017 CurrentStakers
2118 PendingStakers
@@ -142,8 +139,7 @@ func (v *baseStakers) PutValidator(staker *Staker) {
142139 validator.validator = staker
143140
144141 validatorDiff := v.getOrCreateValidatorDiff(staker.SubnetID, staker.NodeID)
145- validatorDiff.validatorStatus = added
146- validatorDiff.validator = staker
142+ validatorDiff.added = staker
147143
148144 v.stakers.ReplaceOrInsert(staker)
149145}
@@ -154,8 +150,8 @@ func (v *baseStakers) DeleteValidator(staker *Staker) {
154150 v.pruneValidator(staker.SubnetID, staker.NodeID)
155151
156152 validatorDiff := v.getOrCreateValidatorDiff(staker.SubnetID, staker.NodeID)
157- validatorDiff.validatorStatus = deleted
158- validatorDiff.validator = staker
153+ validatorDiff.added = nil
154+ validatorDiff.removed = staker
159155
160156 v.stakers.Delete(staker)
161157}
@@ -247,9 +243,7 @@ func (v *baseStakers) getOrCreateValidatorDiff(subnetID ids.ID, nodeID ids.NodeI
247243 }
248244 validatorDiff, ok := subnetValidatorDiffs[nodeID]
249245 if !ok {
250- validatorDiff = &diffValidator{
251- validatorStatus: unmodified,
252- }
246+ validatorDiff = &diffValidator{}
253247 subnetValidatorDiffs[nodeID] = validatorDiff
254248 }
255249 return validatorDiff
@@ -263,23 +257,43 @@ type diffStakers struct {
263257}
264258
265259type diffValidator struct {
266- // validatorStatus describes whether a validator has been added or removed.
267- //
268- // validatorStatus is not affected by delegators ops so unmodified does not
269- // mean that diffValidator hasn't change, since delegators may have changed.
270- validatorStatus diffValidatorStatus
271- validator *Staker
272-
260+ // added represents a validator that was added in this diff, or nil if no
261+ // validator was added. Can be non-nil at the same time as removed to represent a replacement.
262+ added *Staker
263+ // removed represents a validator that was removed in this diff, or nil if no
264+ // validator was removed. Can be non-nil at the same time as added to represent a replacement.
265+ removed *Staker
273266 addedDelegators *btree.BTreeG[*Staker]
274267 deletedDelegators map[ids.ID]*Staker
275268}
276269
270+ // validatorStatus returns the status of the validator in this diff.
271+ //
272+ // validatorStatus is not affected by delegator ops so unmodified does not
273+ // mean that diffValidator hasn't changed, since delegators may have changed.
274+ func (d *diffValidator) validatorStatus() diffValidatorStatus {
275+ // If both added and removed are non-nil, this represents a replacement,
276+ // so we just return the added validator's status, we don't need the removed validator's status.
277+ if d.added != nil {
278+ return added
279+ }
280+ if d.removed != nil {
281+ return deleted
282+ }
283+ return unmodified
284+ }
285+
277286func (d *diffValidator) WeightDiff() (ValidatorWeightDiff, error) {
278- weightDiff := ValidatorWeightDiff{
279- Decrease: d.validatorStatus == deleted,
287+ var weightDiff ValidatorWeightDiff
288+ if d.added != nil {
289+ if err := weightDiff.Add(d.added.Weight); err != nil {
290+ return ValidatorWeightDiff{}, fmt.Errorf("failed to increase weight of added validator: %w", err)
291+ }
280292 }
281- if d.validatorStatus != unmodified {
282- weightDiff.Amount = d.validator.Weight
293+ if d.removed != nil {
294+ if err := weightDiff.Sub(d.removed.Weight); err != nil {
295+ return ValidatorWeightDiff{}, fmt.Errorf("failed to decrease weight of deleted validator: %w", err)
296+ }
283297 }
284298
285299 for _, staker := range d.deletedDelegators {
@@ -304,7 +318,6 @@ func (d *diffValidator) WeightDiff() (ValidatorWeightDiff, error) {
304318
305319// GetValidator attempts to fetch the validator with the given subnetID and
306320// nodeID.
307- // Invariant: Assumes that the validator will never be removed and then added.
308321func (s *diffStakers) GetValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, diffValidatorStatus) {
309322 subnetValidatorDiffs, ok := s.validatorDiffs[subnetID]
310323 if !ok {
@@ -316,22 +329,40 @@ func (s *diffStakers) GetValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker,
316329 return nil, unmodified
317330 }
318331
319- if validatorDiff.validatorStatus == added {
320- return validatorDiff.validator, added
321- }
322- return nil, validatorDiff.validatorStatus
332+ return validatorDiff.added, validatorDiff.validatorStatus()
323333}
324334
325335func (s *diffStakers) PutValidator(staker *Staker) error {
326336 validatorDiff := s.getOrCreateDiff(staker.SubnetID, staker.NodeID)
327- if validatorDiff.validatorStatus == deleted {
328- // Enforce the invariant that a validator cannot be added after being
329- // deleted.
330- return ErrAddingStakerAfterDeletion
337+ if validatorDiff.removed != nil {
338+ // TLDR: The validator was previously deleted, so we remove it from the
339+ // deleted stakers set.
340+
341+ // We set the removed field when we delete the validator that was not added in this diff before.
342+ // So if we reached here, it means we removed it first and now either re-adding it or updating it.
343+ // Either way, we should remove it from the deleted stakers set.
344+ delete(s.deletedStakers, validatorDiff.removed.TxID)
345+ if len(s.deletedStakers) == 0 {
346+ s.deletedStakers = nil
347+ }
348+
349+ // If we're re-adding the exact same validator that was removed,
350+ // the two operations cancel out.
351+ if validatorDiff.removed.Equals(staker) {
352+ validatorDiff.removed = nil
353+ return nil
354+ }
355+
356+ // Attention: We do not return here, but combine with the rest of the flow.
331357 }
332358
333- validatorDiff.validatorStatus = added
334- validatorDiff.validator = staker
359+ validatorDiff.added = staker
360+
361+ // Ensure the newly added staker is not tracked as deleted.
362+ delete(s.deletedStakers, staker.TxID)
363+ if len(s.deletedStakers) == 0 {
364+ s.deletedStakers = nil
365+ }
335366
336367 if s.addedStakers == nil {
337368 s.addedStakers = btree.NewG(defaultTreeDegree, (*Staker).Less)
@@ -342,15 +373,30 @@ func (s *diffStakers) PutValidator(staker *Staker) error {
342373
343374func (s *diffStakers) DeleteValidator(staker *Staker) {
344375 validatorDiff := s.getOrCreateDiff(staker.SubnetID, staker.NodeID)
345- if validatorDiff.validatorStatus == added {
346- // This validator was added and immediately removed in this diff. We
347- // treat it as if it was never added.
348- validatorDiff.validatorStatus = unmodified
349- s.addedStakers.Delete(validatorDiff.validator)
350- validatorDiff.validator = nil
376+ if validatorDiff.added != nil {
377+ // This validator was added in this diff. Rollback the addition.
378+ s.addedStakers.Delete(validatorDiff.added)
379+ validatorDiff.added = nil
380+
381+ // TLDR: If there was a previously removed validator, re-add it to
382+ // deletedStakers since the replacement is being undone.
383+
384+ // We set the removed field when we delete the validator that was not added in this diff before.
385+ // Since we reached here, we have first deleted it, and then added it.
386+ // When we deleted it, we set the removed field and add it to the deleted stakers set.
387+ // When we added it, we removed it from the deleted stakers set.
388+ // Since we're now deleting it again, we should add it back to the deleted stakers set.
389+ // Why are we putting back the validator that was originally removed and not the new staker?
390+ // Because the original staker, validatorDiff.removed was there at the beginning, and the second
391+ // removal is just rolling back the addition of the new staker. We therefore put back original staker.
392+ if validatorDiff.removed != nil {
393+ if s.deletedStakers == nil {
394+ s.deletedStakers = make(map[ids.ID]*Staker)
395+ }
396+ s.deletedStakers[validatorDiff.removed.TxID] = validatorDiff.removed
397+ }
351398 } else {
352- validatorDiff.validatorStatus = deleted
353- validatorDiff.validator = staker
399+ validatorDiff.removed = staker
354400 if s.deletedStakers == nil {
355401 s.deletedStakers = make(map[ids.ID]*Staker)
356402 }
@@ -438,9 +484,7 @@ func (s *diffStakers) getOrCreateDiff(subnetID ids.ID, nodeID ids.NodeID) *diffV
438484 }
439485 validatorDiff, ok := subnetValidatorDiffs[nodeID]
440486 if !ok {
441- validatorDiff = &diffValidator{
442- validatorStatus: unmodified,
443- }
487+ validatorDiff = &diffValidator{}
444488 subnetValidatorDiffs[nodeID] = validatorDiff
445489 }
446490 return validatorDiff
0 commit comments