Skip to content

Commit a625991

Browse files
authored
Merge pull request #2764 from opentensor/fix/roman/balance-unit
Improve logic in Balance magic methods
2 parents ff4ace4 + a1d034f commit a625991

File tree

3 files changed

+118
-58
lines changed

3 files changed

+118
-58
lines changed

bittensor/utils/balance.py

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,35 @@
66
from bittensor.core import settings
77

88

9+
def _check_currencies(self, other):
10+
"""Checks that Balance objects have the same netuids to perform arithmetic operations.
11+
12+
A warning is raised if the netuids differ.
13+
14+
Example:
15+
>>> balance1 = Balance.from_rao(1000).set_unit(12)
16+
>>> balance2 = Balance.from_rao(500).set_unit(12)
17+
>>> balance1 + balance2 # No warning
18+
19+
>>> balance3 = Balance.from_rao(200).set_unit(15)
20+
>>> balance1 + balance3 # Raises DeprecationWarning
21+
22+
In this example:
23+
- `from_rao` creates a Balance instance from the amount in rao (smallest unit).
24+
- `set_unit(12)` sets the unit to correspond to subnet 12 (i.e., Alpha from netuid 12).
25+
"""
26+
if self.netuid != other.netuid:
27+
warnings.simplefilter("default", DeprecationWarning)
28+
warnings.warn(
29+
"Balance objects must have the same netuid (Alpha currency) to perform arithmetic operations.\n"
30+
f"First balance is `{self}`. Second balance is `{other}`.\n\n"
31+
"To create a Balance instance with the correct netuid, use:\n"
32+
"Balance.from_rao(1000).set_unit(12) # 1000 rao in subnet 12",
33+
category=DeprecationWarning,
34+
stacklevel=2,
35+
)
36+
37+
938
class Balance:
1039
"""
1140
Represents the bittensor balance of the wallet, stored as rao (int).
@@ -23,6 +52,7 @@ class Balance:
2352
rao_unit: str = settings.RAO_SYMBOL # This is the rao unit
2453
rao: int
2554
tao: float
55+
netuid: int = 0
2656

2757
def __init__(self, balance: Union[int, float]):
2858
"""
@@ -78,7 +108,8 @@ def __eq__(self, other: Union[int, float, "Balance"]):
78108
if other is None:
79109
return False
80110

81-
if hasattr(other, "rao"):
111+
if isinstance(other, Balance):
112+
_check_currencies(self, other)
82113
return self.rao == other.rao
83114
else:
84115
try:
@@ -92,7 +123,8 @@ def __ne__(self, other: Union[int, float, "Balance"]):
92123
return not self == other
93124

94125
def __gt__(self, other: Union[int, float, "Balance"]):
95-
if hasattr(other, "rao"):
126+
if isinstance(other, Balance):
127+
_check_currencies(self, other)
96128
return self.rao > other.rao
97129
else:
98130
try:
@@ -103,7 +135,8 @@ def __gt__(self, other: Union[int, float, "Balance"]):
103135
raise NotImplementedError("Unsupported type")
104136

105137
def __lt__(self, other: Union[int, float, "Balance"]):
106-
if hasattr(other, "rao"):
138+
if isinstance(other, Balance):
139+
_check_currencies(self, other)
107140
return self.rao < other.rao
108141
else:
109142
try:
@@ -115,111 +148,129 @@ def __lt__(self, other: Union[int, float, "Balance"]):
115148

116149
def __le__(self, other: Union[int, float, "Balance"]):
117150
try:
151+
if isinstance(other, Balance):
152+
_check_currencies(self, other)
118153
return self < other or self == other
119154
except TypeError:
120155
raise NotImplementedError("Unsupported type")
121156

122157
def __ge__(self, other: Union[int, float, "Balance"]):
123158
try:
159+
if isinstance(other, Balance):
160+
_check_currencies(self, other)
124161
return self > other or self == other
125162
except TypeError:
126163
raise NotImplementedError("Unsupported type")
127164

128165
def __add__(self, other: Union[int, float, "Balance"]):
129-
if hasattr(other, "rao"):
130-
return Balance.from_rao(int(self.rao + other.rao))
166+
if isinstance(other, Balance):
167+
_check_currencies(self, other)
168+
return Balance.from_rao(int(self.rao + other.rao)).set_unit(self.netuid)
131169
else:
132170
try:
133171
# Attempt to cast to int from rao
134-
return Balance.from_rao(int(self.rao + other))
172+
return Balance.from_rao(int(self.rao + other)).set_unit(self.netuid)
135173
except (ValueError, TypeError):
136174
raise NotImplementedError("Unsupported type")
137175

138176
def __radd__(self, other: Union[int, float, "Balance"]):
139177
try:
178+
if isinstance(other, Balance):
179+
_check_currencies(self, other)
140180
return self + other
141181
except TypeError:
142182
raise NotImplementedError("Unsupported type")
143183

144184
def __sub__(self, other: Union[int, float, "Balance"]):
145185
try:
186+
if isinstance(other, Balance):
187+
_check_currencies(self, other)
146188
return self + -other
147189
except TypeError:
148190
raise NotImplementedError("Unsupported type")
149191

150192
def __rsub__(self, other: Union[int, float, "Balance"]):
151193
try:
194+
if isinstance(other, Balance):
195+
_check_currencies(self, other)
152196
return -self + other
153197
except TypeError:
154198
raise NotImplementedError("Unsupported type")
155199

156200
def __mul__(self, other: Union[int, float, "Balance"]):
157-
if hasattr(other, "rao"):
158-
return Balance.from_rao(int(self.rao * other.rao))
201+
if isinstance(other, Balance):
202+
_check_currencies(self, other)
203+
return Balance.from_rao(int(self.rao * other.rao)).set_unit(self.netuid)
159204
else:
160205
try:
161206
# Attempt to cast to int from rao
162-
return Balance.from_rao(int(self.rao * other))
207+
return Balance.from_rao(int(self.rao * other)).set_unit(self.netuid)
163208
except (ValueError, TypeError):
164209
raise NotImplementedError("Unsupported type")
165210

166211
def __rmul__(self, other: Union[int, float, "Balance"]):
212+
if isinstance(other, Balance):
213+
_check_currencies(self, other)
167214
return self * other
168215

169216
def __truediv__(self, other: Union[int, float, "Balance"]):
170-
if hasattr(other, "rao"):
171-
return Balance.from_rao(int(self.rao / other.rao))
217+
if isinstance(other, Balance):
218+
_check_currencies(self, other)
219+
return Balance.from_rao(int(self.rao / other.rao)).set_unit(self.netuid)
172220
else:
173221
try:
174222
# Attempt to cast to int from rao
175-
return Balance.from_rao(int(self.rao / other))
223+
return Balance.from_rao(int(self.rao / other)).set_unit(self.netuid)
176224
except (ValueError, TypeError):
177225
raise NotImplementedError("Unsupported type")
178226

179227
def __rtruediv__(self, other: Union[int, float, "Balance"]):
180-
if hasattr(other, "rao"):
181-
return Balance.from_rao(int(other.rao / self.rao))
228+
if isinstance(other, Balance):
229+
_check_currencies(self, other)
230+
return Balance.from_rao(int(other.rao / self.rao)).set_unit(self.netuid)
182231
else:
183232
try:
184233
# Attempt to cast to int from rao
185-
return Balance.from_rao(int(other / self.rao))
234+
return Balance.from_rao(int(other / self.rao)).set_unit(self.netuid)
186235
except (ValueError, TypeError):
187236
raise NotImplementedError("Unsupported type")
188237

189238
def __floordiv__(self, other: Union[int, float, "Balance"]):
190-
if hasattr(other, "rao"):
191-
return Balance.from_rao(int(self.tao // other.tao))
239+
if isinstance(other, Balance):
240+
_check_currencies(self, other)
241+
return Balance.from_rao(int(self.tao // other.tao)).set_unit(self.netuid)
192242
else:
193243
try:
194244
# Attempt to cast to int from rao
195-
return Balance.from_rao(int(self.rao // other))
245+
return Balance.from_rao(int(self.rao // other)).set_unit(self.netuid)
196246
except (ValueError, TypeError):
197247
raise NotImplementedError("Unsupported type")
198248

199249
def __rfloordiv__(self, other: Union[int, float, "Balance"]):
200-
if hasattr(other, "rao"):
201-
return Balance.from_rao(int(other.rao // self.rao))
250+
if isinstance(other, Balance):
251+
_check_currencies(self, other)
252+
return Balance.from_rao(int(other.rao // self.rao)).set_unit(self.netuid)
202253
else:
203254
try:
204255
# Attempt to cast to int from rao
205-
return Balance.from_rao(int(other // self.rao))
256+
return Balance.from_rao(int(other // self.rao)).set_unit(self.netuid)
206257
except (ValueError, TypeError):
207258
raise NotImplementedError("Unsupported type")
208259

209260
def __nonzero__(self) -> bool:
210261
return bool(self.rao)
211262

212263
def __neg__(self):
213-
return Balance.from_rao(-self.rao)
264+
return Balance.from_rao(-self.rao).set_unit(self.netuid)
214265

215266
def __pos__(self):
216-
return Balance.from_rao(self.rao)
267+
return Balance.from_rao(self.rao).set_unit(self.netuid)
217268

218269
def __abs__(self):
219-
return Balance.from_rao(abs(self.rao))
270+
return Balance.from_rao(abs(self.rao)).set_unit(self.netuid)
220271

221272
@staticmethod
222-
def from_float(amount: float, netuid: int = 0):
273+
def from_float(amount: float, netuid: int = 0) -> "Balance":
223274
"""
224275
Given tao, return :func:`Balance` object with rao(``int``) and tao(``float``), where rao = int(tao*pow(10,9))
225276
Args:
@@ -233,7 +284,7 @@ def from_float(amount: float, netuid: int = 0):
233284
return Balance(rao_).set_unit(netuid)
234285

235286
@staticmethod
236-
def from_tao(amount: float, netuid: int = 0):
287+
def from_tao(amount: float, netuid: int = 0) -> "Balance":
237288
"""
238289
Given tao, return Balance object with rao(``int``) and tao(``float``), where rao = int(tao*pow(10,9))
239290
@@ -248,7 +299,7 @@ def from_tao(amount: float, netuid: int = 0):
248299
return Balance(rao_).set_unit(netuid)
249300

250301
@staticmethod
251-
def from_rao(amount: int, netuid: int = 0):
302+
def from_rao(amount: int, netuid: int = 0) -> "Balance":
252303
"""
253304
Given rao, return Balance object with rao(``int``) and tao(``float``), where rao = int(tao*pow(10,9))
254305
@@ -262,7 +313,7 @@ def from_rao(amount: int, netuid: int = 0):
262313
return Balance(amount).set_unit(netuid)
263314

264315
@staticmethod
265-
def get_unit(netuid: int):
316+
def get_unit(netuid: int) -> str:
266317
base = len(units)
267318
if netuid < base:
268319
return units[netuid]
@@ -274,6 +325,7 @@ def get_unit(netuid: int):
274325
return result
275326

276327
def set_unit(self, netuid: int):
328+
self.netuid = netuid
277329
self.unit = Balance.get_unit(netuid)
278330
self.rao_unit = Balance.get_unit(netuid)
279331
return self
@@ -777,18 +829,18 @@ def fixed_to_float(
777829
]
778830

779831

780-
def tao(amount: float) -> Balance:
832+
def tao(amount: float, netuid: int = 0) -> Balance:
781833
"""
782834
Helper function to create a Balance object from a float (Tao)
783835
"""
784-
return Balance.from_tao(amount)
836+
return Balance.from_tao(amount).set_unit(netuid)
785837

786838

787-
def rao(amount: int) -> Balance:
839+
def rao(amount: int, netuid: int = 0) -> Balance:
788840
"""
789841
Helper function to create a Balance object from an int (Rao)
790842
"""
791-
return Balance.from_rao(amount)
843+
return Balance.from_rao(amount).set_unit(netuid)
792844

793845

794846
def check_and_convert_to_balance(

0 commit comments

Comments
 (0)