From eb5379ee9345994950ac7d732e861e9344186f23 Mon Sep 17 00:00:00 2001 From: Maxim Bladyko Date: Wed, 24 Jan 2024 22:14:09 +0300 Subject: [PATCH 1/2] Moved load_state() to the neuron.__init__ --- neurons/validator.py | 1 - template/base/neuron.py | 2 ++ template/base/validator.py | 5 +++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/neurons/validator.py b/neurons/validator.py index 7b502029..d3ee8da0 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -44,7 +44,6 @@ def __init__(self, config=None): super(Validator, self).__init__(config=config) bt.logging.info("load_state()") - self.load_state() # TODO(developer): Anything specific to your use case you can do here diff --git a/template/base/neuron.py b/template/base/neuron.py index ef2caf05..eb0a6542 100644 --- a/template/base/neuron.py +++ b/template/base/neuron.py @@ -99,6 +99,8 @@ def __init__(self, config=None): ) self.step = 0 + self.load_state() + @abstractmethod async def forward(self, synapse: bt.Synapse) -> bt.Synapse: ... diff --git a/template/base/validator.py b/template/base/validator.py index ec069d47..12a7df46 100644 --- a/template/base/validator.py +++ b/template/base/validator.py @@ -19,6 +19,7 @@ import copy +import os import torch import asyncio import threading @@ -339,6 +340,10 @@ def load_state(self): """Loads the state of the validator from a file.""" bt.logging.info("Loading validator state.") + if not os.path.exists(self.config.neuron.full_path + "/state.pt"): + bt.logging.warning("No saved state found") + return + # Load the state of the validator from file. state = torch.load(self.config.neuron.full_path + "/state.pt") self.step = state["step"] From ec7827fdb2a4b4e6e2d74b8522b77069c40306dd Mon Sep 17 00:00:00 2001 From: Maxim Bladyko Date: Thu, 25 Jan 2024 11:26:22 +0300 Subject: [PATCH 2/2] Moving scores to the device on load state --- template/base/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/template/base/validator.py b/template/base/validator.py index 12a7df46..00f592af 100644 --- a/template/base/validator.py +++ b/template/base/validator.py @@ -347,5 +347,5 @@ def load_state(self): # Load the state of the validator from file. state = torch.load(self.config.neuron.full_path + "/state.pt") self.step = state["step"] - self.scores = state["scores"] + self.scores = state["scores"].to(self.device) self.hotkeys = state["hotkeys"]