Skip to content

Commit 328f369

Browse files
improve: new constructors and update conform to pydantic 2
1 parent 2ff7eb1 commit 328f369

File tree

8 files changed

+5121
-19
lines changed

8 files changed

+5121
-19
lines changed

kmm/functional_base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from pydantic import BaseModel, Extra
1+
from pydantic import BaseModel, ConfigDict
22

33

44
class FunctionalBase(BaseModel):
5-
class Config:
6-
allow_mutation = False
7-
extra = Extra.forbid
5+
model_config = ConfigDict(frozen=True, extra="forbid")
86

97
def map(self, fn, *args, **kwargs):
108
return fn(self, *args, **kwargs)
119

1210
def replace(self, **kwargs):
13-
new_dict = self.dict()
11+
new_dict = self.model_dump()
1412
new_dict.update(**kwargs)
1513
return type(self)(**new_dict)

kmm/header/header.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22
from xml.etree import ElementTree
33

4-
from pydantic import validate_arguments
4+
from pydantic import validate_call
55

66
import kmm
77

@@ -12,7 +12,7 @@ class Header(kmm.FunctionalBase):
1212
sync: int
1313

1414
@staticmethod
15-
@validate_arguments
15+
@validate_call
1616
def from_path(path: Path, raise_on_malformed_data: bool = True):
1717
"""
1818
Loads header data from .hdr file.

kmm/positions/positions.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22

33
import pandas as pd
4-
from pydantic import validate_arguments
4+
from pydantic import ConfigDict, validate_call
55

66
import kmm
77
from kmm.header.header import Header
@@ -10,11 +10,10 @@
1010
class Positions(kmm.FunctionalBase):
1111
dataframe: pd.DataFrame
1212

13-
class Config:
14-
arbitrary_types_allowed = True
13+
model_config = ConfigDict(arbitrary_types_allowed=True)
1514

1615
@staticmethod
17-
@validate_arguments
16+
@validate_call
1817
def from_path(
1918
path: Path,
2019
raise_on_malformed_data: bool = True,
@@ -37,7 +36,59 @@ def from_path(
3736
return Positions(dataframe=dataframe)
3837

3938
@staticmethod
40-
@validate_arguments
39+
@validate_call
40+
def read_sync_adjust_from_header_path(
41+
header_path: Path,
42+
kmm2=True,
43+
raise_on_malformed_data: bool = True,
44+
replace_commas: bool = True,
45+
):
46+
"""
47+
Convenience method to load positions from a header file, assumes a kmm file in the same directory.
48+
If kmm2 is True, the method will load a kmm2 file, otherwise a kmm file.
49+
"""
50+
kmm_stem = (
51+
header_path.stem.replace("owlsbtlpos", "").split("_2011T")[0] + "_2011T"
52+
)
53+
if kmm2:
54+
kmm_path = header_path.parent / f"{kmm_stem}.kmm2"
55+
else:
56+
kmm_path = header_path.parent / f"{kmm_stem}.kmm"
57+
return Positions.read_sync_adjust(
58+
kmm_path,
59+
header_path,
60+
raise_on_malformed_data=raise_on_malformed_data,
61+
replace_commas=replace_commas,
62+
)
63+
64+
@staticmethod
65+
@validate_call
66+
def read_sync_adjust_from_measurement_name(
67+
measurement_name: str,
68+
input_dir: Path,
69+
kmm2: bool = True,
70+
raise_on_malformed_data: bool = True,
71+
replace_commas: bool = True,
72+
):
73+
"""
74+
Convenience method to load positions from a measurement name, assumes a kmm2 file and a header file in input_dir.
75+
"""
76+
timestamp = "_".join(measurement_name.split("_")[:-1])
77+
part = measurement_name.split("_")[-1]
78+
if kmm2:
79+
kmm_path = input_dir / f"{timestamp}_2011T.kmm2"
80+
else:
81+
kmm_path = input_dir / f"{timestamp}_2011T.kmm"
82+
header_path = input_dir / f"owlsbtlpos{timestamp}_2011T{part}.hdr"
83+
return Positions.read_sync_adjust(
84+
kmm_path,
85+
header_path,
86+
raise_on_malformed_data=raise_on_malformed_data,
87+
replace_commas=replace_commas,
88+
)
89+
90+
@staticmethod
91+
@validate_call
4192
def read_sync_adjust(
4293
kmm_path: Path,
4394
header_path: Path,
@@ -60,7 +111,7 @@ def read_sync_adjust(
60111
.geodetic()
61112
)
62113

63-
@validate_arguments
114+
@validate_call
64115
def sync_frame_index(
65116
self,
66117
header: Header,
@@ -97,3 +148,17 @@ def test_empty_kmm():
97148
def test_empty_kmm2():
98149
positions = Positions.from_path("tests/empty.kmm2")
99150
assert len(positions.dataframe) == 0
151+
152+
153+
def test_read_sync_adjust_from_header_path():
154+
positions = Positions.read_sync_adjust_from_header_path(
155+
"tests/owlsbtlpos20210819_165120_2011TA.hdr", kmm2=True
156+
)
157+
assert len(positions.dataframe) > 0
158+
159+
160+
def test_read_sync_adjust_from_measurement_name():
161+
positions = Positions.read_sync_adjust_from_measurement_name(
162+
"20210819_165120_A", "tests"
163+
)
164+
assert len(positions.dataframe) > 0

kmm/positions/read_kmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
import numpy as np
55
import pandas as pd
6-
from pydantic import validate_arguments
6+
from pydantic import validate_call
77

88

9-
@validate_arguments
9+
@validate_call
1010
def read_kmm(path: Path, replace_commas: bool = True):
1111
try:
1212
if replace_commas:

kmm/positions/read_kmm2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import pandas as pd
7-
from pydantic import validate_arguments
7+
from pydantic import validate_call
88

99
pattern = re.compile(r".+\[.+\]")
1010
pattern2 = re.compile(r"CMAST")
@@ -46,7 +46,7 @@
4646
)
4747

4848

49-
@validate_arguments
49+
@validate_call
5050
def read_kmm2(
5151
path: Path, raise_on_malformed_data: bool = True, replace_commas: bool = True
5252
):

kmm/positions/sync_frame_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import numpy as np
2-
from pydantic import validate_arguments
2+
from pydantic import validate_call
33

44
from kmm import CarDirection, PositionAdjustment
55
from kmm.header.header import Header
66
from kmm.positions.positions import Positions
77

88

9-
@validate_arguments(config=dict(arbitrary_types_allowed=True))
9+
@validate_call(config=dict(arbitrary_types_allowed=True))
1010
def sync_frame_index(
1111
positions: Positions,
1212
header: Header,

0 commit comments

Comments
 (0)