Skip to content

Commit 67f1a30

Browse files
committed
add tests for track model
1 parent fa7c7ce commit 67f1a30

File tree

9 files changed

+217
-43
lines changed

9 files changed

+217
-43
lines changed

antares-python/src/antares/cli.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import logging
4+
from typing import NoReturn
45

56
import typer
67
from rich.console import Console
@@ -15,7 +16,7 @@
1516
console = Console(theme=Theme({"info": "green", "warn": "yellow", "error": "bold red"}))
1617

1718

18-
def handle_error(message: str, code: int, json_output: bool = False):
19+
def handle_error(message: str, code: int, json_output: bool = False) -> NoReturn:
1920
logger = logging.getLogger("antares.cli")
2021
if json_output:
2122
typer.echo(json.dumps({"error": message}), err=True)
@@ -50,7 +51,7 @@ def reset(
5051
config: str = typer.Option(None),
5152
verbose: bool = typer.Option(False, "--verbose", "-v"),
5253
json_output: bool = typer.Option(False, "--json", help="Output in JSON format"),
53-
):
54+
) -> None:
5455
client = build_client(config, verbose, json_output)
5556
try:
5657
client.reset_simulation()
@@ -67,7 +68,7 @@ def add_ship(
6768
config: str = typer.Option(None, help="Path to the configuration file"),
6869
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"),
6970
json_output: bool = typer.Option(False, "--json", help="Output in JSON format"),
70-
):
71+
) -> None:
7172
client = build_client(config, verbose, json_output)
7273
try:
7374
ship = ShipConfig(initial_position=(x, y))
@@ -84,13 +85,13 @@ def subscribe(
8485
verbose: bool = typer.Option(False, "--verbose", "-v"),
8586
json_output: bool = typer.Option(False, "--json", help="Output in JSON format"),
8687
log_file: str = typer.Option("antares.log", help="Path to log file"),
87-
):
88+
) -> None:
8889
setup_logging(log_file=log_file, level=logging.DEBUG if verbose else logging.INFO)
8990
logger = logging.getLogger("antares.cli")
9091

9192
client = build_client(config, verbose, json_output)
9293

93-
async def _sub():
94+
async def _sub() -> None:
9495
try:
9596
async for event in client.subscribe():
9697
if json_output:
Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
11
from collections.abc import AsyncIterator
2+
from typing import Any
23

34
from antares.client.rest import RestClient
45
from antares.client.tcp import TCPSubscriber
56
from antares.config import AntaresSettings
67
from antares.models.ship import ShipConfig
8+
from antares.models.track import Track
79

810

911
class AntaresClient:
1012
def __init__(
1113
self,
12-
base_url: str | None = None,
13-
tcp_host: str | None = None,
14-
tcp_port: int | None = None,
15-
timeout: float | None = None,
16-
auth_token: str | None = None,
14+
**kwargs: Any,
1715
) -> None:
1816
"""
1917
Public interface for interacting with the Antares simulation engine.
2018
Accepts config overrides directly or falls back to environment-based configuration.
2119
"""
2220

23-
overrides = {
24-
"base_url": base_url,
25-
"tcp_host": tcp_host,
26-
"tcp_port": tcp_port,
27-
"timeout": timeout,
28-
"auth_token": auth_token,
29-
}
30-
clean_overrides = {k: v for k, v in overrides.items() if v is not None}
21+
# Only include kwargs that match AntaresSettings fields
22+
valid_fields = AntaresSettings.model_fields.keys()
23+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields and v is not None}
3124

3225
# Merge provided arguments with environment/.env via AntaresSettings
33-
self._settings = AntaresSettings(**clean_overrides)
26+
self._settings = AntaresSettings(**filtered_kwargs)
3427

3528
self._rest = RestClient(
3629
base_url=self._settings.base_url,
@@ -51,12 +44,12 @@ def add_ship(self, ship: ShipConfig) -> None:
5144
"""
5245
return self._rest.add_ship(ship)
5346

54-
async def subscribe(self) -> AsyncIterator[dict]:
47+
async def subscribe(self) -> AsyncIterator[Track]:
5548
"""
5649
Subscribes to live simulation data over TCP.
5750
5851
Yields:
59-
Parsed simulation event data as dictionaries.
52+
Parsed simulation event data as Track objects.
6053
"""
6154
async for event in self._tcp.subscribe():
6255
yield event

antares-python/src/antares/client/tcp.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import AsyncIterator
55

66
from antares.errors import SubscriptionError
7+
from antares.models.track import Track
78

89
logger = logging.getLogger(__name__)
910

@@ -26,25 +27,27 @@ def __init__(self, host: str, port: int, reconnect: bool = True) -> None:
2627
self.port = port
2728
self.reconnect = reconnect
2829

29-
async def subscribe(self) -> AsyncIterator[dict]:
30+
async def subscribe(self) -> AsyncIterator[Track]:
3031
"""
31-
Connects to the TCP server and yields simulation events as parsed dictionaries.
32+
Connects to the TCP server and yields simulation events as Track objects.
3233
This is an infinite async generator until disconnected or cancelled.
3334
3435
Yields:
35-
Parsed simulation events.
36+
Parsed simulation events as Track objects.
3637
"""
3738
while True:
3839
try:
3940
reader, _ = await asyncio.open_connection(self.host, self.port)
4041
while not reader.at_eof():
4142
line = await reader.readline()
4243
if line:
43-
yield json.loads(line.decode())
44+
track = Track.from_csv_row(line.decode())
45+
yield track
4446
except (
4547
ConnectionRefusedError,
4648
asyncio.IncompleteReadError,
4749
json.JSONDecodeError,
50+
ValueError,
4851
) as e:
4952
logger.error("TCP stream error: %s", e)
5053
if not self.reconnect:

antares-python/src/antares/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ class AntaresSettings(BaseSettings):
77
Supports environment variables and `.env` file loading.
88
"""
99

10-
base_url: str
10+
base_url: str = "http://localhost:8000"
1111
tcp_host: str = "localhost"
1212
tcp_port: int = 9000
1313
timeout: float = 5.0
1414
auth_token: str | None = None
1515

1616
model_config = SettingsConfigDict(
1717
env_file=".env",
18-
env_prefix="ANTARES_",
18+
env_prefix="antares_",
1919
case_sensitive=False,
2020
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from typing import ClassVar
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class Track(BaseModel):
7+
id: int
8+
year: int
9+
month: int
10+
day: int
11+
hour: int
12+
minute: int
13+
second: int
14+
millisecond: int
15+
stat: str
16+
type_: str = Field(alias="type") # maps from "type" input
17+
name: str
18+
linemask: int
19+
size: int
20+
range: float
21+
azimuth: float
22+
lat: float
23+
long: float
24+
speed: float
25+
course: float
26+
quality: int
27+
l16quality: int
28+
lacks: int
29+
winrgw: int
30+
winazw: float
31+
stderr: float
32+
33+
# expected order of fields from TCP stream
34+
__field_order__: ClassVar[list[str]] = [
35+
"id",
36+
"year",
37+
"month",
38+
"day",
39+
"hour",
40+
"minute",
41+
"second",
42+
"millisecond",
43+
"stat",
44+
"type_",
45+
"name",
46+
"linemask",
47+
"size",
48+
"range",
49+
"azimuth",
50+
"lat",
51+
"long",
52+
"speed",
53+
"course",
54+
"quality",
55+
"l16quality",
56+
"lacks",
57+
"winrgw",
58+
"winazw",
59+
"stderr",
60+
]
61+
62+
@classmethod
63+
def from_csv_row(cls, line: str) -> "Track":
64+
parts = line.strip().split(",")
65+
if len(parts) != len(cls.__field_order__):
66+
raise ValueError(f"Expected {len(cls.__field_order__)} fields, got {len(parts)}")
67+
68+
converted = {}
69+
for field_name, value in zip(cls.__field_order__, parts, strict=True):
70+
field_info = cls.model_fields[field_name]
71+
field_type = field_info.annotation
72+
73+
if field_type is None:
74+
raise ValueError(f"Field '{field_name}' has no type annotation")
75+
76+
# Use alias if defined
77+
key = field_info.alias or field_name
78+
try:
79+
# We trust simple coercion here; Pydantic will do final validation
80+
converted[key] = field_type(value)
81+
except Exception as e:
82+
raise ValueError(f"Invalid value for field '{field_name}': {value} ({e})") from e
83+
84+
return cls(**converted)

antares-python/template.env

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ANTARES_BASE_URL = "http://localhost:8080"

antares-python/tests/client/test_tcp.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44

55
from antares.client.tcp import TCPSubscriber
66
from antares.errors import SubscriptionError
7+
from antares.models.track import Track
78

89

910
@pytest.mark.asyncio
10-
async def test_subscribe_success(monkeypatch):
11-
# Simulated lines returned from the TCP stream
12-
lines = [b'{"event": "ok"}\n', b'{"event": "done"}\n', b""]
11+
async def test_subscribe_success(monkeypatch, sample_track_line):
12+
# Simulated CSV lines returned from the TCP stream
13+
encoded_lines = [sample_track_line.encode() + b"\n", b""]
1314

1415
async def fake_readline():
15-
return lines.pop(0)
16+
return encoded_lines.pop(0)
1617

1718
# Simulate EOF after all lines are read
18-
eof_flags = [False, False, True]
19+
eof_flags = [False, True]
1920

2021
fake_reader = AsyncMock()
2122
fake_reader.readline = AsyncMock(side_effect=fake_readline)
@@ -27,7 +28,11 @@ async def fake_readline():
2728
subscriber = TCPSubscriber("localhost", 1234, reconnect=False)
2829

2930
events = [event async for event in subscriber.subscribe()]
30-
assert events == [{"event": "ok"}, {"event": "done"}]
31+
expected_lat = -33.45
32+
assert len(events) == 1
33+
assert isinstance(events[0], Track)
34+
assert events[0].name == "Eagle-1"
35+
assert events[0].lat == expected_lat
3136

3237

3338
@pytest.mark.asyncio
@@ -41,7 +46,7 @@ async def test_subscribe_failure(monkeypatch):
4146

4247

4348
@pytest.mark.asyncio
44-
async def test_subscribe_reconnects_on_failure(monkeypatch):
49+
async def test_subscribe_reconnects_on_failure(monkeypatch, sample_track_line):
4550
class OneMessageReader:
4651
def __init__(self):
4752
self.called = False
@@ -52,7 +57,7 @@ def at_eof(self):
5257
async def readline(self):
5358
if not self.called:
5459
self.called = True
55-
return b'{"event": "recovered"}\n'
60+
return sample_track_line.encode() + b"\n"
5661
return b""
5762

5863
open_calls = []
@@ -71,6 +76,65 @@ async def fake_open_connection(host, port):
7176
events = []
7277
async for event in subscriber.subscribe():
7378
events.append(event)
74-
break # Exit after one event
79+
break # exit after first track
7580

76-
assert events == [{"event": "recovered"}]
81+
assert len(events) == 1
82+
assert isinstance(events[0], Track)
83+
assert events[0].name == "Eagle-1"
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_subscribe_invalid_field_count(monkeypatch):
88+
invalid_line = "1,2025,4,11"
89+
90+
async def fake_readline():
91+
return invalid_line.encode() + b"\n"
92+
93+
fake_reader = AsyncMock()
94+
fake_reader.readline = AsyncMock(side_effect=fake_readline)
95+
fake_reader.at_eof = MagicMock(side_effect=[False, True])
96+
97+
monkeypatch.setattr("asyncio.open_connection", AsyncMock(return_value=(fake_reader, None)))
98+
99+
subscriber = TCPSubscriber("localhost", 1234, reconnect=False)
100+
101+
with pytest.raises(SubscriptionError) as excinfo:
102+
async for _ in subscriber.subscribe():
103+
pass
104+
105+
assert "Expected 25 fields" in str(excinfo.value)
106+
107+
108+
@pytest.mark.asyncio
109+
async def test_subscribe_invalid_value(monkeypatch, sample_track_line):
110+
bad_line = sample_track_line.replace("1,", "bad_id,", 1)
111+
112+
async def fake_readline():
113+
return bad_line.encode() + b"\n"
114+
115+
fake_reader = AsyncMock()
116+
fake_reader.readline = AsyncMock(side_effect=fake_readline)
117+
fake_reader.at_eof = MagicMock(side_effect=[False, True])
118+
119+
monkeypatch.setattr("asyncio.open_connection", AsyncMock(return_value=(fake_reader, None)))
120+
121+
subscriber = TCPSubscriber("localhost", 1234, reconnect=False)
122+
123+
with pytest.raises(SubscriptionError) as excinfo:
124+
async for _ in subscriber.subscribe():
125+
pass
126+
127+
assert "Invalid value for field 'id'" in str(excinfo.value)
128+
129+
130+
def test_from_csv_field_type_none(monkeypatch):
131+
class FakeTrack(Track):
132+
__field_order__ = ["id"]
133+
id: int
134+
135+
FakeTrack.model_fields["id"].annotation = None
136+
137+
with pytest.raises(ValueError) as excinfo:
138+
FakeTrack.from_csv_row("123")
139+
140+
assert "has no type annotation" in str(excinfo.value)

0 commit comments

Comments
 (0)