Skip to content

Commit cdf89a5

Browse files
authored
fix: INVALID_NODE_ACCOUNT during node switching (#93)
* refactor: improve transaction handling and signing Signed-off-by: dosi <[email protected]> * refactor: update query to build correctly payment transaction Signed-off-by: dosi <[email protected]> * test: fix unit tests Signed-off-by: dosi <[email protected]> * refactor: update _require_frozen to properly check if dict is empty Signed-off-by: dosi <[email protected]> * refactor: make transaction_body_bytes and signature_map private in Transaction class Signed-off-by: dosi <[email protected]> * docs: add comments to Transaction class Signed-off-by: dosi <[email protected]> * refactor: update query and query_payment to use the right Transaction attributes Signed-off-by: dosi <[email protected]> * test: fix unit tests Signed-off-by: dosi <[email protected]> * refactor: update Transaction class to address PR comments Signed-off-by: dosi <[email protected]> * revert: remove unintended changes from test_executable.py Signed-off-by: dosi <[email protected]> --------- Signed-off-by: dosi <[email protected]>
1 parent 43d19ed commit cdf89a5

15 files changed

+179
-64
lines changed

src/hiero_sdk_python/query/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _build_query_payment_transaction(self, payer_account_id, payer_private_key,
145145
tx.transaction_id = TransactionId.generate(payer_account_id)
146146

147147
body_bytes = tx.build_transaction_body().SerializeToString()
148-
tx.transaction_body_bytes = body_bytes
148+
tx._transaction_body_bytes.setdefault(node_account_id, body_bytes)
149149
tx.sign(payer_private_key)
150150

151151
return tx.to_proto()

src/hiero_sdk_python/transaction/query_payment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def build_query_payment_transaction(
2626
tx.transaction_id = TransactionId.generate(payer_account_id)
2727

2828
body_bytes = tx.build_transaction_body().SerializeToString()
29-
tx.transaction_body_bytes = body_bytes
29+
tx._transaction_body_bytes.setdefault(node_account_id, body_bytes)
3030

3131
tx.sign(payer_private_key)
3232
return tx.to_proto()

src/hiero_sdk_python/transaction/transaction.py

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22

3+
from hiero_sdk_python.account.account_id import AccountId
34
from hiero_sdk_python.exceptions import PrecheckError
45
from hiero_sdk_python.executable import _Executable, _ExecutionState
56
from hiero_sdk_python.hapi.services import (basic_types_pb2, transaction_body_pb2, transaction_contents_pb2, transaction_pb2)
@@ -8,7 +9,6 @@
89
from hiero_sdk_python.transaction.transaction_id import TransactionId
910
from hiero_sdk_python.transaction.transaction_response import TransactionResponse
1011

11-
1212
class Transaction(_Executable):
1313
"""
1414
Base class for all Hedera transactions.
@@ -34,8 +34,17 @@ def __init__(self):
3434
self.transaction_valid_duration = 120
3535
self.generate_record = False
3636
self.memo = ""
37-
self.transaction_body_bytes = None
38-
self.signature_map = basic_types_pb2.SignatureMap()
37+
# Maps each node's AccountId to its corresponding transaction body bytes
38+
# This allows us to maintain separate transaction bodies for each node
39+
# which is necessary in case node is unhealthy and we have to switch it with other node.
40+
# Each transaction body has the AccountId of the node it's being submitted to.
41+
# If these do not match `INVALID_NODE_ACCOUNT` error will occur.
42+
self._transaction_body_bytes: dict[AccountId, bytes] = {}
43+
44+
# Maps transaction body bytes to their associated signatures
45+
# This allows us to maintain the signatures for each unique transaction
46+
# and ensures that the correct signatures are used when submitting transactions
47+
self._signature_map: dict[bytes, basic_types_pb2.SignatureMap] = {}
3948
self._default_transaction_fee = 2_000_000
4049
self.operator_account_id = None
4150

@@ -148,18 +157,24 @@ def sign(self, private_key):
148157
"""
149158
# We require the transaction to be frozen before signing
150159
self._require_frozen()
160+
161+
# We sign the bodies for each node in case we need to switch nodes during execution.
162+
for body_bytes in self._transaction_body_bytes.values():
163+
signature = private_key.sign(body_bytes)
151164

152-
signature = private_key.sign(self.transaction_body_bytes)
153-
154-
public_key_bytes = private_key.public_key().to_bytes_raw()
165+
public_key_bytes = private_key.public_key().to_bytes_raw()
155166

156-
sig_pair = basic_types_pb2.SignaturePair(
157-
pubKeyPrefix=public_key_bytes,
158-
ed25519=signature
159-
)
167+
sig_pair = basic_types_pb2.SignaturePair(
168+
pubKeyPrefix=public_key_bytes,
169+
ed25519=signature
170+
)
160171

161-
self.signature_map.sigPair.append(sig_pair)
172+
# We initialize the signature map for this body_bytes if it doesn't exist yet
173+
self._signature_map.setdefault(body_bytes, basic_types_pb2.SignatureMap())
162174

175+
# Append the signature pair to the signature map for this transaction body
176+
self._signature_map[body_bytes].sigPair.append(sig_pair)
177+
163178
return self
164179

165180
def to_proto(self):
@@ -175,9 +190,17 @@ def to_proto(self):
175190
# We require the transaction to be frozen before converting to protobuf
176191
self._require_frozen()
177192

193+
body_bytes = self._transaction_body_bytes.get(self.node_account_id)
194+
if body_bytes is None:
195+
raise ValueError(f"No transaction body found for node {self.node_account_id}")
196+
197+
sig_map = self._signature_map.get(body_bytes)
198+
if sig_map is None:
199+
raise ValueError("No signature map found for the current transaction body")
200+
178201
signed_transaction = transaction_contents_pb2.SignedTransaction(
179-
bodyBytes=self.transaction_body_bytes,
180-
sigMap=self.signature_map
202+
bodyBytes=body_bytes,
203+
sigMap=sig_map
181204
)
182205

183206
return transaction_pb2.Transaction(
@@ -197,18 +220,22 @@ def freeze_with(self, client):
197220
Raises:
198221
Exception: If required IDs are not set.
199222
"""
200-
if self.transaction_body_bytes is not None:
223+
if self._transaction_body_bytes:
201224
return self
202-
225+
203226
if self.transaction_id is None:
204227
self.transaction_id = client.generate_transaction_id()
205-
206-
if self.node_account_id is None:
207-
self.node_account_id = client.network.current_node._account_id
208-
209-
# print(f"Transaction's node account ID set to: {self.node_account_id}")
210-
self.transaction_body_bytes = self.build_transaction_body().SerializeToString()
211-
228+
229+
# We iterate through every node in the client's network
230+
# For each node, set the node_account_id and build the transaction body
231+
# This allows the transaction to be submitted to any node in the network
232+
for node in client.network.nodes:
233+
self.node_account_id = node._account_id
234+
self._transaction_body_bytes[node._account_id] = self.build_transaction_body().SerializeToString()
235+
236+
# Set the node account id to the current node in the network
237+
self.node_account_id = client.network.current_node._account_id
238+
212239
return self
213240

214241
def execute(self, client):
@@ -228,7 +255,7 @@ def execute(self, client):
228255
MaxAttemptsError: If the transaction/query fails after the maximum number of attempts
229256
ReceiptStatusError: If the query fails with a receipt status error
230257
"""
231-
if self.transaction_body_bytes is None:
258+
if not self._transaction_body_bytes:
232259
self.freeze_with(client)
233260

234261
if self.operator_account_id is None:
@@ -257,8 +284,13 @@ def is_signed_by(self, public_key):
257284
bool: True if signed by the given public key, False otherwise.
258285
"""
259286
public_key_bytes = public_key.to_bytes_raw()
260-
261-
for sig_pair in self.signature_map.sigPair:
287+
288+
sig_map = self._signature_map.get(self._transaction_body_bytes.get(self.node_account_id))
289+
290+
if sig_map is None:
291+
return False
292+
293+
for sig_pair in sig_map.sigPair:
262294
if sig_pair.pubKeyPrefix == public_key_bytes:
263295
return True
264296
return False
@@ -317,7 +349,7 @@ def _require_not_frozen(self):
317349
Raises:
318350
Exception: If the transaction has already been frozen.
319351
"""
320-
if self.transaction_body_bytes is not None:
352+
if self._transaction_body_bytes:
321353
raise Exception("Transaction is immutable; it has been frozen.")
322354

323355
def _require_frozen(self):
@@ -330,7 +362,7 @@ def _require_frozen(self):
330362
Raises:
331363
Exception: If the transaction has not been frozen yet.
332364
"""
333-
if self.transaction_body_bytes is None:
365+
if not self._transaction_body_bytes:
334366
raise Exception("Transaction is not frozen")
335367

336368
def set_transaction_memo(self, memo):

tests/unit/test_account_create_transaction.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def test_account_create_transaction_build(mock_account_ids):
5555
assert transaction_body.cryptoCreateAccount.initialBalance == 100000000
5656
assert transaction_body.cryptoCreateAccount.memo == "Test account"
5757

58-
# This test uses fixture mock_account_ids as parameter
59-
def test_account_create_transaction_sign(mock_account_ids):
58+
# This test uses fixture (mock_account_ids, mock_client) as parameter
59+
def test_account_create_transaction_sign(mock_account_ids, mock_client):
6060
"""Test signing the account create transaction."""
6161
operator_id, node_account_id = mock_account_ids
6262

@@ -72,12 +72,23 @@ def test_account_create_transaction_sign(mock_account_ids):
7272
)
7373
account_tx.transaction_id = generate_transaction_id(operator_id)
7474
account_tx.node_account_id = node_account_id
75-
account_tx.freeze_with(None)
76-
account_tx.sign(operator_private_key)
75+
account_tx.freeze_with(mock_client)
7776

78-
# Verify signature was added
79-
assert len(account_tx.signature_map.sigPair) == 1, \
77+
# Add first signiture
78+
account_tx.sign(mock_client.operator_private_key)
79+
body_bytes = account_tx._transaction_body_bytes[node_account_id]
80+
81+
assert body_bytes in account_tx._signature_map, "Body bytes should be a key in the signature map dictionary"
82+
assert len(account_tx._signature_map[body_bytes].sigPair) == 1, \
8083
"Transaction should have exactly one signature"
84+
85+
# Add second signiture
86+
account_tx.sign(operator_private_key)
87+
body_bytes = account_tx._transaction_body_bytes[node_account_id]
88+
89+
assert body_bytes in account_tx._signature_map, "Body bytes should be a key in the signature map dictionary"
90+
assert len(account_tx._signature_map[body_bytes].sigPair) == 2, \
91+
"Transaction should have exactly two signatures"
8192

8293
def test_account_create_transaction():
8394
"""Integration test for AccountCreateTransaction with retry and response handling."""

tests/unit/test_executable.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,51 @@ def test_topic_create_transaction_fails_on_nonretriable_error():
380380

381381
# Verify the error contains the expected status
382382
assert str(ResponseCode.INVALID_TRANSACTION_BODY) in str(exc_info.value)
383+
384+
def test_transaction_node_switching_body_bytes():
385+
"""Test that execution switches nodes after receiving a non-retriable error."""
386+
ok_response = TransactionResponseProto(nodeTransactionPrecheckCode=ResponseCode.OK)
387+
error = RealRpcError(grpc.StatusCode.UNAVAILABLE, "Test error")
388+
389+
receipt_response = response_pb2.Response(
390+
transactionGetReceipt=transaction_get_receipt_pb2.TransactionGetReceiptResponse(
391+
header=response_header_pb2.ResponseHeader(
392+
nodeTransactionPrecheckCode=ResponseCode.OK
393+
),
394+
receipt=transaction_receipt_pb2.TransactionReceipt(
395+
status=ResponseCode.SUCCESS
396+
)
397+
)
398+
)
399+
# First node gives error, second node gives OK, third node gives error
400+
response_sequences = [
401+
[error],
402+
[ok_response, receipt_response],
403+
]
404+
405+
with mock_hedera_servers(response_sequences) as client, patch('time.sleep'):
406+
# We set the current node to 0
407+
client.network._node_index = 0
408+
client.network.current_node = client.network.nodes[0]
409+
410+
transaction = (
411+
AccountCreateTransaction()
412+
.set_key(PrivateKey.generate().public_key())
413+
.set_initial_balance(100_000_000)
414+
.freeze_with(client)
415+
.sign(client.operator_private_key)
416+
)
417+
418+
for node in client.network.nodes:
419+
assert transaction._transaction_body_bytes.get(node._account_id) is not None, "Transaction body bytes should be set for all nodes"
420+
sig_map = transaction._signature_map.get(transaction._transaction_body_bytes[node._account_id])
421+
assert sig_map is not None, "Signature map should be set for all nodes"
422+
assert len(sig_map.sigPair) == 1, "Signature map should have one signature"
423+
assert sig_map.sigPair[0].pubKeyPrefix == client.operator_private_key.public_key().to_bytes_raw(), "Signature should be for the operator"
424+
425+
try:
426+
transaction.execute(client)
427+
except (Exception, grpc.RpcError) as e:
428+
pytest.fail(f"Transaction execution should not raise an exception, but raised: {e}")
429+
# Verify we're now on the second node
430+
assert client.network.current_node._account_id == AccountId(0, 0, 4), "Client should have switched to the second node"

tests/unit/test_token_associate_transaction.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,13 @@ def test_sign_transaction(mock_account_ids, mock_client):
6969

7070
# Sign the transaction
7171
associate_tx.sign(private_key)
72+
73+
node_id = mock_client.network.current_node._account_id
74+
body_bytes = associate_tx._transaction_body_bytes[node_id]
7275

73-
assert len(associate_tx.signature_map.sigPair) == 1
74-
sig_pair = associate_tx.signature_map.sigPair[0]
76+
assert body_bytes in associate_tx._signature_map, "Body bytes should be a key in the signature map dictionary"
77+
assert len(associate_tx._signature_map[body_bytes].sigPair) == 1
78+
sig_pair = associate_tx._signature_map[body_bytes].sigPair[0]
7579

7680
assert sig_pair.pubKeyPrefix == b'public_key'
7781
assert sig_pair.ed25519 == b'signature'

tests/unit/test_token_create_transaction.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,23 @@ def test_sign_transaction(mock_account_ids, mock_client):
276276
# Sign with both sign keys
277277
token_tx.sign(private_key) # Necessary
278278
token_tx.sign(private_key_admin) # Since admin key exists
279+
280+
node_id = mock_client.network.current_node._account_id
281+
body_bytes = token_tx._transaction_body_bytes[node_id]
279282

280283
# Expect 2 sigPairs
281-
assert len(token_tx.signature_map.sigPair) == 2
284+
assert len(token_tx._signature_map[body_bytes].sigPair) == 2
282285

283-
sig_pair = token_tx.signature_map.sigPair[0]
286+
sig_pair = token_tx._signature_map[body_bytes].sigPair[0]
284287
assert sig_pair.pubKeyPrefix == b"public_key"
285288
assert sig_pair.ed25519 == b"signature"
286289

287-
sig_pair_admin = token_tx.signature_map.sigPair[1]
290+
sig_pair_admin = token_tx._signature_map[body_bytes].sigPair[1]
288291
assert sig_pair_admin.pubKeyPrefix == b"admin_public_key"
289292
assert sig_pair_admin.ed25519 == b"admin_signature"
290293

291294
# Confirm that neither sigPair belongs to supply_key or freeze_key:
292-
for sig_pair in token_tx.signature_map.sigPair:
295+
for sig_pair in token_tx._signature_map[body_bytes].sigPair:
293296
assert sig_pair.pubKeyPrefix not in (b"supply_public_key", b"freeze_public_key")
294297

295298
# This test uses fixture (mock_account_ids, mock_client) as parameter
@@ -445,7 +448,7 @@ def test_transaction_execution_failure(mock_account_ids):
445448
token_tx.transaction_id = generate_transaction_id(treasury_account)
446449

447450
# Set the transaction body bytes to avoid calling build_transaction_body
448-
token_tx.transaction_body_bytes = b"mock_body_bytes"
451+
token_tx._transaction_body_bytes = b"mock_body_bytes"
449452

450453
# Mock the client and its operator_private_key
451454
token_tx.client = MagicMock()
@@ -543,8 +546,8 @@ def test_overwrite_defaults(mock_account_ids, mock_client):
543546
# Confirm no adminKey was set
544547
assert not tx_body.tokenCreation.HasField("adminKey")
545548

546-
# This test uses fixture mock_account_ids as parameter
547-
def test_transaction_freeze_prevents_modification(mock_account_ids):
549+
# This test uses fixture (mock_account_ids, mock_client) as parameter
550+
def test_transaction_freeze_prevents_modification(mock_account_ids, mock_client):
548551
"""
549552
Test that after freeze() is called, attempts to modify TokenCreateTransaction
550553
parameters raise an exception indicating immutability.
@@ -562,10 +565,9 @@ def test_transaction_freeze_prevents_modification(mock_account_ids):
562565

563566
transaction.node_account_id = node_account_id
564567
transaction.transaction_id = generate_transaction_id(treasury_account)
565-
transaction.client = MagicMock()
566568

567569
# Freeze the transaction
568-
transaction.freeze_with(transaction.client)
570+
transaction.freeze_with(mock_client)
569571

570572
# Attempt to overwrite after freeze - expect exceptions
571573
with pytest.raises(Exception, match="Transaction is immutable; it has been frozen."):

tests/unit/test_token_delete_transaction.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ def test_sign_transaction(mock_account_ids, mock_client):
6060

6161
delete_tx.sign(private_key)
6262

63-
assert len(delete_tx.signature_map.sigPair) == 1
64-
sig_pair = delete_tx.signature_map.sigPair[0]
63+
node_id = mock_client.network.current_node._account_id
64+
body_bytes = delete_tx._transaction_body_bytes[node_id]
65+
66+
assert len(delete_tx._signature_map[body_bytes].sigPair) == 1
67+
sig_pair = delete_tx._signature_map[body_bytes].sigPair[0]
6568
assert sig_pair.pubKeyPrefix == b'public_key'
6669
assert sig_pair.ed25519 == b'signature'
6770

tests/unit/test_token_dissociate_transaction.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,11 @@ def test_sign_transaction(mock_account_ids, mock_client):
8989

9090
dissociate_tx.sign(private_key)
9191

92-
assert len(dissociate_tx.signature_map.sigPair) == 1
93-
sig_pair = dissociate_tx.signature_map.sigPair[0]
92+
node_id = mock_client.network.current_node._account_id
93+
body_bytes = dissociate_tx._transaction_body_bytes[node_id]
94+
95+
assert len(dissociate_tx._signature_map[body_bytes].sigPair) == 1
96+
sig_pair = dissociate_tx._signature_map[body_bytes].sigPair[0]
9497

9598
assert sig_pair.pubKeyPrefix == b'public_key'
9699
assert sig_pair.ed25519 == b'signature'

tests/unit/test_token_freeze_transaction.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ def test_sign_transaction(mock_account_ids, mock_client):
7979

8080
freeze_tx.sign(freeze_key)
8181

82-
assert len(freeze_tx.signature_map.sigPair) == 1
83-
sig_pair = freeze_tx.signature_map.sigPair[0]
82+
node_id = mock_client.network.current_node._account_id
83+
body_bytes = freeze_tx._transaction_body_bytes[node_id]
84+
85+
assert len(freeze_tx._signature_map[body_bytes].sigPair) == 1
86+
sig_pair = freeze_tx._signature_map[body_bytes].sigPair[0]
8487
assert sig_pair.pubKeyPrefix == b'public_key'
8588
assert sig_pair.ed25519 == b'signature'
8689

0 commit comments

Comments
 (0)