Skip to content

Commit 6e36746

Browse files
author
FelixAbrahamsson
committed
improve: groupby over track_section and km instead of total_meter
1 parent 71a3ef5 commit 6e36746

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

kmm/positions/wire_camera_positions.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from pydantic import validate_arguments
3+
import pandas as pd
34

45
from kmm import CarDirection
56
from kmm.positions.positions import Positions
@@ -24,31 +25,27 @@ def wire_camera_positions(positions: Positions, car_direction: CarDirection):
2425

2526

2627
def kmm_directions(df):
27-
28-
total_meter = (df["kilometer"] * 1000 + df["meter"]).values
29-
30-
track_section_changes = np.argwhere(np.concatenate([
31-
np.array([True]),
32-
df["track_section"].values[1:] != df["track_section"].values[:-1],
33-
])).squeeze(1).tolist() + [len(df)]
34-
35-
return np.concatenate([
36-
(
37-
np.ones(to_index - from_index, dtype=np.uint8)
38-
* kmm_direction(total_meter[from_index: to_index])
39-
)
40-
for from_index, to_index in zip(
41-
track_section_changes[:-1],
42-
track_section_changes[1:],
43-
)
44-
])
45-
46-
47-
def kmm_direction(total_meter):
48-
diffs = np.clip(total_meter[1:] - total_meter[:-1], -1, 1)
49-
if len(diffs) >= 10 and (diffs > 0).mean() < 0.9 and (diffs < 0).mean() < 0.9:
50-
raise ValueError("Unable to determine direction of kmm numbers.", diffs)
51-
return int(np.sign(diffs.sum()))
28+
records = list()
29+
for (track_section, kilometer), group in df.groupby(["track_section", "kilometer"]):
30+
diffs = np.sign(group["meter"].values[1:] - group["meter"].values[:-1])
31+
diffs = diffs[diffs != 0]
32+
if len(diffs) >= 10 and (diffs > 0).mean() < 0.9 and (diffs < 0).mean() < 0.9:
33+
raise ValueError(
34+
f"Inconsistent directions at track_section {track_section}, kilometer {kilometer}."
35+
)
36+
records.append(dict(
37+
track_section=track_section,
38+
kilometer=kilometer,
39+
direction=int(np.sign(diffs.sum())),
40+
))
41+
return df.merge(
42+
pd.DataFrame.from_records(
43+
records,
44+
columns=["track_section", "kilometer", "direction"],
45+
),
46+
on=["track_section", "kilometer"],
47+
how="left",
48+
)["direction"].values
5249

5350

5451
def test_camera_positions_kmm():

0 commit comments

Comments
 (0)