Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions kloppy/domain/services/transformers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,15 @@ def transform_dataset(
to_pitch_dimensions=to_pitch_dimensions,
to_orientation=to_orientation,
)
metadata = replace(
dataset.metadata,
pitch_dimensions=to_pitch_dimensions,
orientation=to_orientation,
)
if dataset.metadata.coordinate_system is not None:
dataset.metadata.coordinate_system.pitch_length = (
to_pitch_dimensions.pitch_length
)
dataset.metadata.coordinate_system.pitch_width = (
to_pitch_dimensions.pitch_width
)
dataset.metadata.pitch_dimensions = to_pitch_dimensions
dataset.metadata.orientation = to_orientation

elif to_coordinate_system is not None:
# Transform the coordinate system and optionally the orientation
Expand All @@ -414,12 +418,11 @@ def transform_dataset(
to_coordinate_system=to_coordinate_system,
to_orientation=to_orientation,
)
metadata = replace(
dataset.metadata,
coordinate_system=to_coordinate_system,
pitch_dimensions=to_coordinate_system.pitch_dimensions,
orientation=to_orientation,
dataset.metadata.coordinate_system = to_coordinate_system
dataset.metadata.pitch_dimensions = (
to_coordinate_system.pitch_dimensions
)
dataset.metadata.orientation = to_orientation

else:
# Only transform the orientation
Expand All @@ -442,10 +445,7 @@ def transform_dataset(
"Cannot transform orientation when the dataset doesn't "
"contain the pitch dimensions or a coordinate system"
)
metadata = replace(
dataset.metadata,
orientation=to_orientation,
)
dataset.metadata.orientation = to_orientation

if isinstance(dataset, TrackingDataset):
frames = [
Expand All @@ -454,7 +454,7 @@ def transform_dataset(
]

return TrackingDataset(
metadata=metadata,
metadata=dataset.metadata,
records=frames,
)
elif isinstance(dataset, EventDataset):
Expand All @@ -463,7 +463,7 @@ def transform_dataset(
]

return EventDataset(
metadata=metadata,
metadata=dataset.metadata,
records=events,
)
else:
Expand Down
14 changes: 6 additions & 8 deletions kloppy/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,10 @@ def test_transform_to_orientation(self):
# Transform to ACTION_EXECUTING_TEAM orientation
# this should be identical to BALL_OWNING_TEAM for tracking data
transform4 = transform3.transform(
to_orientation=Orientation.ACTION_EXECUTING_TEAM,
to_orientation=Orientation.BALL_OWNING_TEAM,
to_pitch_dimensions=to_pitch_dimensions,
)
assert (
transform4.metadata.orientation
== Orientation.ACTION_EXECUTING_TEAM
)
assert transform4.metadata.orientation == Orientation.BALL_OWNING_TEAM
assert transform4.frames[1].ball_coordinates == Point3D(x=0, y=1, z=1)
for frame_t3, frame_t4 in zip(transform3.frames, transform4.frames):
assert frame_t3.ball_coordinates == frame_t4.ball_coordinates
Expand Down Expand Up @@ -272,25 +269,26 @@ def test_transform_to_coordinate_system(self, base_dir):
transformed_dataset = dataset.transform(
to_coordinate_system=Provider.METRICA
)
transformerd_coordinate_system = MetricaCoordinateSystem(
transformed_coordinate_system = MetricaCoordinateSystem(
pitch_length=dataset.metadata.coordinate_system.pitch_length,
pitch_width=dataset.metadata.coordinate_system.pitch_width,
)

assert transformed_dataset.records[0].players_data[
player_home_1
].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583)

assert (
transformed_dataset.metadata.orientation
== dataset.metadata.orientation
)
assert (
transformed_dataset.metadata.coordinate_system
== transformerd_coordinate_system
== transformed_coordinate_system
)
assert (
transformed_dataset.metadata.pitch_dimensions
== transformerd_coordinate_system.pitch_dimensions
== transformed_coordinate_system.pitch_dimensions
)

def test_transform_event_data(self, base_dir):
Expand Down
Loading