From 23e7aaf3f19393e48a7981adbbda8a2866eb21b1 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 7 Jul 2025 16:59:12 -0700 Subject: [PATCH] ruff issues --- chia/_tests/plot_sync/test_sender.py | 11 +++++---- chia/_tests/util/test_network.py | 33 ++++++++++++++++++++++++++- chia/cmds/configure.py | 29 +++++++++-------------- chia/cmds/peer_funcs.py | 18 +++++++-------- chia/simulator/block_tools.py | 3 ++- chia/util/network.py | 9 ++++++++ chia/util/virtual_project_analysis.py | 2 +- 7 files changed, 69 insertions(+), 36 deletions(-) diff --git a/chia/_tests/plot_sync/test_sender.py b/chia/_tests/plot_sync/test_sender.py index 21b9671a479a..ea011082ce5f 100644 --- a/chia/_tests/plot_sync/test_sender.py +++ b/chia/_tests/plot_sync/test_sender.py @@ -14,6 +14,7 @@ from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse from chia.protocols.outbound_message import NodeType from chia.protocols.protocol_message_types import ProtocolMessageTypes +from chia.server.ws_connection import WSChiaConnection from chia.simulator.block_tools import BlockTools @@ -37,11 +38,11 @@ def test_set_connection_values(bt: BlockTools, seeded_random: random.Random) -> # Test invalid NodeType values for connection_type in NodeType: if connection_type != NodeType.FARMER: - pytest.raises( - InvalidConnectionTypeError, - sender.set_connection, - get_dummy_connection(connection_type, farmer_connection.peer_node_id), - ) + with pytest.raises(InvalidConnectionTypeError): + dummy_connection: WSChiaConnection = get_dummy_connection( + connection_type, farmer_connection.peer_node_id + ) # type: ignore[assignment] + sender.set_connection(dummy_connection) # Test setting a valid connection works sender.set_connection(farmer_connection) # type:ignore[arg-type] assert sender._connection is not None diff --git a/chia/_tests/util/test_network.py b/chia/_tests/util/test_network.py index 96d254ff96ff..3686bed1b526 100644 --- a/chia/_tests/util/test_network.py +++ b/chia/_tests/util/test_network.py @@ -8,7 +8,38 @@ import pytest from chia.util.ip_address import IPAddress -from chia.util.network import resolve +from chia.util.network import parse_host_port, resolve + + +@pytest.mark.parametrize( + "host_port, expected_host, expected_port", + [ + ("127.0.0.1:8080", "127.0.0.1", 8080), + ("[::1]:8080", "::1", 8080), + ("example.com:8080", "example.com", 8080), + ("localhost:8555", "localhost", 8555), + ], +) +def test_parse_host_port(host_port: str, expected_host: str, expected_port: int) -> None: + host, port = parse_host_port(host_port) + assert host == expected_host + assert port == expected_port + + +@pytest.mark.parametrize( + "host_port", + [ + "127.0.0.1", + "::1", + "example.com", + "localhost", + "127.0.0.1:", + ":8080", + ], +) +def test_parse_host_port_invalid(host_port: str) -> None: + with pytest.raises(ValueError): + parse_host_port(host_port) @pytest.mark.anyio diff --git a/chia/cmds/configure.py b/chia/cmds/configure.py index 6dc140eb61f4..4af54772098e 100644 --- a/chia/cmds/configure.py +++ b/chia/cmds/configure.py @@ -16,6 +16,7 @@ save_config, str2bool, ) +from chia.util.network import parse_host_port def configure( @@ -42,28 +43,20 @@ def configure( change_made = False if set_node_introducer: try: - if set_node_introducer.index(":"): - host, port = ( - ":".join(set_node_introducer.split(":")[:-1]), - set_node_introducer.split(":")[-1], - ) - config["full_node"]["introducer_peer"]["host"] = host - config["full_node"]["introducer_peer"]["port"] = int(port) - config["introducer"]["port"] = int(port) - print("Node introducer updated") - change_made = True + host, port = parse_host_port(set_node_introducer) + config["full_node"]["introducer_peer"]["host"] = host + config["full_node"]["introducer_peer"]["port"] = port + config["introducer"]["port"] = port + print("Node introducer updated") + change_made = True except ValueError: print("Node introducer address must be in format [IP:Port]") if set_farmer_peer: try: - if set_farmer_peer.index(":"): - host, port = ( - ":".join(set_farmer_peer.split(":")[:-1]), - set_farmer_peer.split(":")[-1], - ) - set_peer_info(config["harvester"], peer_type=NodeType.FARMER, peer_host=host, peer_port=int(port)) - print("Farmer peer updated, make sure your harvester has the proper cert installed") - change_made = True + host, port = parse_host_port(set_farmer_peer) + set_peer_info(config["harvester"], peer_type=NodeType.FARMER, peer_host=host, peer_port=port) + print("Farmer peer updated, make sure your harvester has the proper cert installed") + change_made = True except ValueError: print("Farmer address must be in format [IP:Port]") if set_fullnode_port: diff --git a/chia/cmds/peer_funcs.py b/chia/cmds/peer_funcs.py index 76997c6f4a3f..7e37bb61aed0 100644 --- a/chia/cmds/peer_funcs.py +++ b/chia/cmds/peer_funcs.py @@ -5,24 +5,22 @@ from chia.cmds.cmds_util import NODE_TYPES, get_any_service_client from chia.rpc.rpc_client import RpcClient +from chia.util.network import parse_host_port async def add_node_connection(rpc_client: RpcClient, add_connection: str) -> None: - if ":" not in add_connection: - print("Enter a valid IP and port in the following format: 10.5.4.3:8000") - else: - ip, port = ( - ":".join(add_connection.split(":")[:-1]), - add_connection.split(":")[-1], - ) - print(f"Connecting to {ip}, {port}") + try: + host, port = parse_host_port(add_connection) + print(f"Connecting to {host}, {port}") try: - result = await rpc_client.open_connection(ip, int(port)) + result = await rpc_client.open_connection(host, port) err = result.get("error") if result["success"] is False or err is not None: print(err) except Exception: - print(f"Failed to connect to {ip}:{port}") + print(f"Failed to connect to {host}:{port}") + except ValueError: + print("Enter a valid IP and port in the following format: 10.5.4.3:8000") async def remove_node_connection(rpc_client: RpcClient, remove_connection: str) -> None: diff --git a/chia/simulator/block_tools.py b/chia/simulator/block_tools.py index f9a981930fcc..b129b5fb251f 100644 --- a/chia/simulator/block_tools.py +++ b/chia/simulator/block_tools.py @@ -1734,7 +1734,8 @@ def get_plot_dir(plot_dir_name: str = "test-plots", automated_testing: bool = Tr if not automated_testing: # make sure we don't accidentally stack directories. root_dir = ( root_dir.parent - if root_dir.parts[-1] == plot_dir_name.split("/")[0] or root_dir.parts[-1] == plot_dir_name.split("\\")[0] + if root_dir.parts[-1] == plot_dir_name.split("/", maxsplit=1)[0] + or root_dir.parts[-1] == plot_dir_name.split("\\", maxsplit=1)[0] else root_dir ) cache_path = root_dir.joinpath(plot_dir_name) diff --git a/chia/util/network.py b/chia/util/network.py index 054843f6863f..1c98e165baea 100644 --- a/chia/util/network.py +++ b/chia/util/network.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from ipaddress import IPv4Network, IPv6Network, ip_address from typing import Any, Literal, Optional, Union +from urllib.parse import urlsplit from aiohttp import web from aiohttp.log import web_logger @@ -20,6 +21,14 @@ from chia.util.task_referencer import create_referenced_task +def parse_host_port(host_port: str) -> tuple[str, int]: + """Parse a host:port string into a tuple of (host, port), raising ValueError on failure.""" + result = urlsplit(f"//{host_port}") + if result.hostname and result.port: + return result.hostname, result.port + raise ValueError(f"Invalid host:port string: {host_port}") + + @final @dataclass class WebServer: diff --git a/chia/util/virtual_project_analysis.py b/chia/util/virtual_project_analysis.py index bc0a29c88ad2..f861768b2870 100644 --- a/chia/util/virtual_project_analysis.py +++ b/chia/util/virtual_project_analysis.py @@ -296,7 +296,7 @@ def parse_file_or_package(identifier: str) -> FileOrPackage: if "(" not in identifier: return File(Path(identifier)) else: - return File(Path(identifier.split("(")[0].strip())) + return File(Path(identifier.split("(", maxsplit=1)[0].strip())) if ".py" not in identifier and identifier[0] == "(" and identifier[-1] == ")": return Package(identifier[1:-1]) # strip parens