diff --git a/template/base/neuron.py b/template/base/neuron.py index 364bdcdf..0c30478b 100644 --- a/template/base/neuron.py +++ b/template/base/neuron.py @@ -108,6 +108,8 @@ def __init__(self, config=None): ) self.step = 0 + self._last_updated_block = self.metagraph.last_update[self.uid] + @abstractmethod async def forward(self, synapse: bt.Synapse) -> bt.Synapse: ... @@ -125,9 +127,11 @@ def sync(self): if self.should_sync_metagraph(): self.resync_metagraph() + self._last_updated_block = self.block if self.should_set_weights(): self.set_weights() + self._last_updated_block = self.block # Always save state. self.save_state() @@ -148,9 +152,10 @@ def should_sync_metagraph(self): """ Check if enough epoch blocks have elapsed since the last checkpoint to sync. """ - return ( - self.block - self.metagraph.last_update[self.uid] - ) > self.config.neuron.epoch_length + elapsed = self.block - self._last_updated_block + + # Only set weights if epoch has passed + return elapsed > self.config.neuron.epoch_length def should_set_weights(self) -> bool: # Don't set weights on initialization. @@ -161,12 +166,10 @@ def should_set_weights(self) -> bool: if self.config.neuron.disable_set_weights: return False - # Define appropriate logic for when set weights. - return ( - (self.block - self.metagraph.last_update[self.uid]) - > self.config.neuron.epoch_length - and self.neuron_type != "MinerNeuron" - ) # don't set weights if you're a miner + elapsed = self.block - self._last_updated_block + + # Only set weights if epoch has passed and this isn't a MinerNeuron. + return elapsed > self.config.neuron.epoch_length and self.neuron_type != "MinerNeuron" def save_state(self): bt.logging.trace(