Skip to content

Commit e4320ec

Browse files
UnravelSports [JB] and probberechts
authored
feat(tracab): decoupled Tracab dat / json from meta data file types (#364)
--------- Co-authored-by: UnravelSports [JB] <jors@unravelsports.com> Co-authored-by: Pieter Robberechts <pieter.robberechts@kuleuven.be>
1 parent 2af3859 commit e4320ec

File tree

18 files changed

+1043
-801
lines changed

18 files changed

+1043
-801
lines changed

kloppy/_providers/tracab.py

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
from typing import Optional, Union, Type
2-
1+
import warnings
2+
from typing import Optional
33

44
from kloppy.domain import TrackingDataset
5-
from kloppy.infra.serializers.tracking.tracab.tracab_dat import (
6-
TRACABDatDeserializer,
7-
)
8-
from kloppy.infra.serializers.tracking.tracab.tracab_json import (
9-
TRACABJSONDeserializer,
5+
from kloppy.infra.serializers.tracking.tracab.deserializer import (
6+
TRACABDeserializer,
107
TRACABInputs,
118
)
12-
from kloppy.io import FileLike, open_as_file, get_file_extension
9+
from kloppy.io import FileLike, open_as_file
1310

1411

1512
def load(
@@ -18,17 +15,67 @@ def load(
1815
sample_rate: Optional[float] = None,
1916
limit: Optional[int] = None,
2017
coordinates: Optional[str] = None,
21-
only_alive: Optional[bool] = True,
18+
only_alive: bool = True,
2219
file_format: Optional[str] = None,
2320
) -> TrackingDataset:
24-
if file_format == "dat":
25-
deserializer_class = TRACABDatDeserializer
26-
elif file_format == "json":
27-
deserializer_class = TRACABJSONDeserializer
28-
else:
29-
deserializer_class = identify_deserializer(meta_data, raw_data)
30-
31-
deserializer = deserializer_class(
21+
"""
22+
Load TRACAB tracking data.
23+
24+
Args:
25+
meta_data: A JSON or XML feed containing the meta data.
26+
raw_data: A JSON or dat feed containing the raw tracking data.
27+
sample_rate: Sample the data at a specific rate.
28+
limit: Limit the number of frames to load to the first `limit` frames.
29+
coordinates: The coordinate system to use.
30+
only_alive: Only include frames in which the game is not paused.
31+
file_format:
32+
Deprecated. The format will be inferred based on the file extensions.
33+
34+
Returns:
35+
The parsed tracking data.
36+
37+
Notes:
38+
Tracab distributes its metadata in various formats. Kloppy tries to
39+
infer automatically which format applies. Currently, kloppy supports
40+
the following formats:
41+
42+
- **Flat XML structure**:
43+
44+
<root>
45+
<GameID>13331</GameID>
46+
<CompetitionID>55</CompetitionID>
47+
...
48+
</root>
49+
50+
- **Hierarchical XML structure**:
51+
52+
<match iId="1" ...>
53+
<period iId="1" iStartFrame="1848508" iEndFrame="1916408"/>
54+
...
55+
</match>
56+
57+
- **JSON structure**:
58+
59+
{
60+
"GameID": 1,
61+
"CompetitionID": 1,
62+
"SeasonID": 2023,
63+
...
64+
}
65+
66+
If parsing fails for a supported format or you encounter an unsupported
67+
structure, please create an issue on the kloppy GitHub repository
68+
with a sample of the problematic data.
69+
"""
70+
# file format is deprecated
71+
if file_format is not None:
72+
warnings.warn(
73+
"file_format is deprecated. This is now automatically infered.",
74+
DeprecationWarning,
75+
stacklevel=2,
76+
)
77+
78+
deserializer = TRACABDeserializer(
3279
sample_rate=sample_rate,
3380
limit=limit,
3481
coordinate_system=coordinates,
@@ -40,22 +87,3 @@ def load(
4087
return deserializer.deserialize(
4188
inputs=TRACABInputs(meta_data=meta_data_fp, raw_data=raw_data_fp)
4289
)
43-
44-
45-
def identify_deserializer(
46-
meta_data: FileLike,
47-
raw_data: FileLike,
48-
) -> Union[Type[TRACABDatDeserializer], Type[TRACABJSONDeserializer]]:
49-
meta_data_extension = get_file_extension(meta_data)
50-
raw_data_extension = get_file_extension(raw_data)
51-
52-
if meta_data_extension == ".xml" and raw_data_extension == ".dat":
53-
deserializer = TRACABDatDeserializer
54-
elif meta_data_extension == ".json" and raw_data_extension == ".json":
55-
deserializer = TRACABJSONDeserializer
56-
else:
57-
raise ValueError(
58-
"Tracab file format could not be recognized, please specify"
59-
)
60-
61-
return deserializer

kloppy/infra/serializers/tracking/metrica_epts/reader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import re
2-
from typing import List, Iterator, IO
32
from datetime import timedelta
4-
3+
from typing import IO, Iterator, List
54

65
from .models import (
7-
PlayerChannel,
86
DataFormatSpecification,
97
EPTSMetadata,
8+
PlayerChannel,
109
Sensor,
1110
)
1211

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
from .common import TRACABInputs
2-
from .tracab_dat import TRACABDatDeserializer
3-
from .tracab_json import TRACABJSONDeserializer
1+
from .deserializer import TRACABDeserializer, TRACABInputs
42

53
__all__ = [
6-
"TRACABDatDeserializer",
7-
"TRACABJSONDeserializer",
4+
"TRACABDeserializer",
85
"TRACABInputs",
96
]

kloppy/infra/serializers/tracking/tracab/common.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import logging
2+
import warnings
3+
from typing import IO, NamedTuple, Optional, Union
4+
5+
from kloppy.domain import (
6+
AttackingDirection,
7+
DatasetFlag,
8+
Metadata,
9+
Orientation,
10+
Provider,
11+
TrackingDataset,
12+
attacking_direction_from_frame,
13+
)
14+
from kloppy.utils import performance_logging
15+
16+
from ..deserializer import TrackingDataDeserializer
17+
from .parsers import get_metadata_parser, get_raw_data_parser
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class TRACABInputs(NamedTuple):
    """Pair of binary streams handed to the TRACAB deserializer."""

    # Stream holding the game's meta data feed.
    meta_data: IO[bytes]
    # Stream holding the game's raw tracking data feed.
    raw_data: IO[bytes]
25+
26+
27+
class TRACABDeserializer(TrackingDataDeserializer[TRACABInputs]):
    """Build a ``TrackingDataset`` from a TRACAB meta data and raw data feed.

    The concrete file formats are not handled here: the parsers returned by
    ``get_metadata_parser`` and ``get_raw_data_parser`` do the extraction.
    """

    def __init__(
        self,
        limit: Optional[int] = None,
        sample_rate: Optional[float] = None,
        coordinate_system: Optional[Union[str, Provider]] = None,
        only_alive: bool = True,
    ):
        """Configure the deserializer.

        Args:
            limit: Forwarded to the base deserializer; used in
                ``deserialize`` to stop collecting frames early.
            sample_rate: Forwarded to the base deserializer and passed on to
                the raw data parser.
            coordinate_system: Forwarded to the base deserializer.
            only_alive: Passed on to the raw data parser's
                ``extract_frames``.
        """
        super().__init__(limit, sample_rate, coordinate_system)
        self.only_alive = only_alive

    @property
    def provider(self) -> Provider:
        """The provider handled by this deserializer (always TRACAB)."""
        return Provider.TRACAB

    @staticmethod
    def _orientation_from_frames(frames):
        """Derive the orientation from the first frame of period 1.

        Falls back to ``Orientation.NOT_SET`` (with a warning) when no frame
        of period 1 is available.
        """
        try:
            reference_frame = next(
                candidate for candidate in frames if candidate.period.id == 1
            )
            direction = attacking_direction_from_frame(reference_frame)
            if direction == AttackingDirection.LTR:
                return Orientation.HOME_AWAY
            return Orientation.AWAY_HOME
        except StopIteration:
            warnings.warn(
                "Could not determine orientation of dataset, defaulting to NOT_SET"
            )
            return Orientation.NOT_SET

    def deserialize(self, inputs: TRACABInputs) -> TrackingDataset:
        """Parse both feeds and assemble the tracking dataset.

        Args:
            inputs: The meta data and raw data streams of one game.

        Returns:
            A dataset holding the transformed frames and the game metadata.
        """
        with performance_logging("Loading metadata", logger=logger):
            meta_parser = get_metadata_parser(inputs.meta_data)
            pitch_length, pitch_width = meta_parser.extract_pitch_dimensions()
            teams = meta_parser.extract_lineups()
            periods = meta_parser.extract_periods()
            frame_rate = meta_parser.extract_frame_rate()
            date = meta_parser.extract_date()
            game_id = meta_parser.extract_game_id()
            orientation = meta_parser.extract_orientation()

        transformer = self.get_transformer(
            pitch_length=pitch_length, pitch_width=pitch_width
        )

        with performance_logging("Loading data", logger=logger):
            data_parser = get_raw_data_parser(
                inputs.raw_data, periods, teams, frame_rate
            )
            frames = []
            parsed_frames = data_parser.extract_frames(
                self.sample_rate, self.only_alive
            )
            for collected, parsed in enumerate(parsed_frames, start=1):
                frames.append(transformer.transform_frame(parsed))

                # Stop as soon as the configured frame budget is reached.
                if self.limit and collected >= (
                    self.limit / self.sample_rate
                ):
                    break

        # The meta data does not always carry the orientation; when it is
        # missing, infer it from the frames themselves.
        if orientation is None:
            orientation = self._orientation_from_frames(frames)

        metadata = Metadata(
            teams=list(teams),
            periods=periods,
            pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions,
            score=None,
            frame_rate=frame_rate,
            orientation=orientation,
            provider=Provider.TRACAB,
            flags=DatasetFlag.BALL_OWNING_TEAM | DatasetFlag.BALL_STATE,
            coordinate_system=transformer.get_to_coordinate_system(),
            date=date,
            game_id=game_id,
        )

        return TrackingDataset(records=frames, metadata=metadata)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from typing import IO, List, Optional, Tuple
2+
3+
from lxml import objectify
4+
5+
from kloppy.domain import Period, Team
6+
7+
from .metadata.base import TracabMetadataParser
8+
from .metadata.flat_xml import TracabFlatXMLMetadataParser
9+
from .metadata.hierarchical_xml import TracabHierarchicalXMLMetadataParser
10+
from .metadata.json import TracabJSONMetadataParser
11+
from .raw_data.base import TracabDataParser
12+
from .raw_data.dat import TracabDatParser
13+
from .raw_data.json import TracabJSONParser
14+
15+
16+
# UTF-8 byte-order mark that some exporters prepend to XML / JSON files.
_UTF8_BOM = b"\xef\xbb\xbf"
# Number of bytes fetched per read while looking for the first real byte.
_SNIFF_CHUNK_SIZE = 64


def _sniff_leading_byte(feed: IO[bytes]) -> bytes:
    """Return the first significant byte of ``feed`` and rewind the feed.

    An optional UTF-8 byte-order mark and any leading ASCII whitespace are
    skipped, so the returned byte is the one that actually identifies the
    format (``<`` for XML, ``{`` for JSON, ...).

    Args:
        feed: Binary stream to inspect. It is read from its current position
            and always repositioned at offset 0 afterwards.

    Returns:
        A ``bytes`` object of length 1, or ``b""`` when the feed holds no
        significant content.
    """
    chunk = feed.read(_SNIFF_CHUNK_SIZE)
    if chunk.startswith(_UTF8_BOM):
        chunk = chunk[len(_UTF8_BOM):]
    significant = chunk.lstrip()
    # Keep reading while everything seen so far is whitespace.
    while not significant:
        chunk = feed.read(_SNIFF_CHUNK_SIZE)
        if not chunk:
            break
        significant = chunk.lstrip()
    feed.seek(0)
    return significant[:1]


def get_metadata_parser(
    feed: IO[bytes], feed_format: Optional[str] = None
) -> TracabMetadataParser:
    """Create the metadata parser matching the format of ``feed``.

    Args:
        feed: Binary stream with the game's meta data.
        feed_format: One of ``"JSON"``, ``"FLAT_XML"`` or
            ``"HIERARCHICAL_XML"`` (case-insensitive). When omitted, the
            format is inferred from the content: a feed whose first
            significant byte is ``<`` is treated as XML, anything else as
            JSON. An XML feed is hierarchical when the parsed root exposes a
            ``match`` attribute and flat otherwise.

    Returns:
        A parser instance wrapping ``feed``.

    Raises:
        ValueError: If the format must be inferred from an empty feed, or if
            ``feed_format`` is not a known format.
    """
    if feed_format is None:
        leading = _sniff_leading_byte(feed)
        if not leading:
            raise ValueError(
                "Cannot infer the metadata feed format of an empty feed"
            )
        if leading == b"<":
            meta_data = objectify.fromstring(feed.read())
            if hasattr(meta_data, "match"):
                feed_format = "HIERARCHICAL_XML"
            else:
                feed_format = "FLAT_XML"
            # The sniffing consumed the stream; hand a rewound feed to the
            # parser.
            feed.seek(0)
        else:
            feed_format = "JSON"

    normalized_format = feed_format.upper()
    if normalized_format == "JSON":
        return TracabJSONMetadataParser(feed)
    if normalized_format == "FLAT_XML":
        return TracabFlatXMLMetadataParser(feed)
    if normalized_format == "HIERARCHICAL_XML":
        return TracabHierarchicalXMLMetadataParser(feed)
    raise ValueError(f"Unknown metadata feed format {feed_format}")


def get_raw_data_parser(
    feed: IO[bytes],
    periods: List[Period],
    teams: Tuple[Team, Team],
    frame_rate: int,
    feed_format: Optional[str] = None,
) -> TracabDataParser:
    """Create the raw data parser matching the format of ``feed``.

    Args:
        feed: Binary stream with the raw tracking data.
        periods: Periods of the game, forwarded to the parser.
        teams: Home and away team, forwarded to the parser.
        frame_rate: Frame rate of the tracking data, forwarded to the parser.
        feed_format: ``"DAT"`` or ``"JSON"`` (case-insensitive). When
            omitted, a feed whose first significant byte is ``{`` is treated
            as JSON and anything else as DAT.

    Returns:
        A parser instance wrapping ``feed``.

    Raises:
        ValueError: If the format must be inferred from an empty feed, or if
            ``feed_format`` is not a known format.
    """
    if feed_format is None:
        leading = _sniff_leading_byte(feed)
        if not leading:
            raise ValueError(
                "Cannot infer the raw data feed format of an empty feed"
            )
        feed_format = "JSON" if leading == b"{" else "DAT"

    normalized_format = feed_format.upper()
    if normalized_format == "DAT":
        return TracabDatParser(feed, periods, teams, frame_rate)
    if normalized_format == "JSON":
        return TracabJSONParser(feed, periods, teams, frame_rate)
    raise ValueError(f"Unknown raw data feed format {feed_format}")
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Base class for all parsers that can handle Tracab metadata files.
2+
3+
A parser reads a single metadata file and should extend the 'TracabMetadataParser'
4+
class to extract the data about periods, lineups, pitch dimensions, etc.
5+
"""
6+
7+
from abc import ABC, abstractmethod
8+
from datetime import datetime
9+
from typing import IO, List, Optional, Tuple
10+
11+
from kloppy.domain import Orientation, Period, Score, Team
12+
13+
14+
class TracabMetadataParser(ABC):
    """Interface for readers of a single Tracab metadata file.

    Subclasses must provide the periods, lineups, pitch dimensions and frame
    rate. The remaining extractors are optional and return ``None`` unless
    overridden.
    """

    def __init__(self, feed: IO[bytes]) -> None:
        """Set up the parser for a metadata stream.

        The base implementation does nothing with ``feed``; consuming it is
        left to the subclasses.

        Args:
            feed : The metadata of a game to parse.
        """

    # -- mandatory extractors ------------------------------------------------

    @abstractmethod
    def extract_periods(self) -> List[Period]:
        """Return the periods of the game."""

    @abstractmethod
    def extract_lineups(self) -> Tuple[Team, Team]:
        """Return the home team and the away team."""

    @abstractmethod
    def extract_pitch_dimensions(self) -> Tuple[float, float]:
        """Return the pitch size as a ``(length, width)`` pair."""

    @abstractmethod
    def extract_frame_rate(self) -> int:
        """Return the frame rate of the tracking data."""

    # -- optional extractors (``None`` when not available) ------------------

    def extract_score(self) -> Optional[Score]:
        """Return the score of the game, if available."""
        return None

    def extract_date(self) -> Optional[datetime]:
        """Return the date of the game, if available."""
        return None

    def extract_game_week(self) -> Optional[str]:
        """Return the game week, if available."""
        return None

    def extract_game_id(self) -> Optional[str]:
        """Return the id of the game, if available."""
        return None

    def extract_orientation(self) -> Optional[Orientation]:
        """Return the orientation of the data, if available."""
        return None

0 commit comments

Comments
 (0)