Skip to content

Commit a24ed48

Browse files
Merge pull request #42 from MannLabs/shape-attributes-1
Custom shapes attributes
2 parents 0787c86 + 3377eb9 commit a24ed48

File tree

2 files changed

+131
-21
lines changed

2 files changed

+131
-21
lines changed

src/lmd/lib.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from skimage.segmentation import find_boundaries
3232

3333
from pathlib import Path
34-
from typing import Callable, Optional, Union, Iterable
34+
from typing import Callable, Optional, Union, Iterable, Any
3535

3636
import geopandas as gpd
3737
import matplotlib.pyplot as plt
@@ -263,7 +263,7 @@ def add_shape(self, shape: Shape):
263263
TypeError("Provided shape is not of type Shape")
264264

265265
def new_shape(
266-
self, points: np.ndarray, well: Optional[str] = None, name: Optional[str] = None
266+
self, points: np.ndarray, well: Optional[str] = None, name: Optional[str] = None, **custom_attributes
267267
):
268268
"""Directly create a new Shape in the current collection.
269269
@@ -273,13 +273,17 @@ def new_shape(
273273
well: Well in which to sort the shape after cutting. For example A1, A2 or B3.
274274
275275
name: Name of the shape.
276-
"""
277276
277+
custom_attributes: Custom shape metadata that can be added as additional xml-element to the shape.
278+
All values are converted to strings.
279+
280+
"""
278281
to_add = Shape(
279282
points,
280283
well=well,
281284
name=name,
282285
orientation_transform=self.orientation_transform,
286+
**custom_attributes
283287
)
284288
self.add_shape(to_add)
285289

@@ -319,22 +323,29 @@ def to_geopandas(self, *attrs: str) -> gpd.GeoDataFrame:
319323
.. code-block:: python
320324
# Generate collection
321325
collection = pylmd.Collection()
322-
shape = pylmd.Shape(np.array([[ 0, 0], [ 0, -1], [ 1, 0], [ 0, 0]]), well="A1", name="Shape_1", orientation_transform=None)
326+
shape = pylmd.Shape(
327+
np.array([[ 0, 0], [ 0, -1], [ 1, 0], [ 0, 0]]),
328+
well="A1",
329+
name="Shape_1",
330+
metadata1="A",
331+
metadata2="B",
332+
orientation_transform=None
333+
)
323334
collection.add_shape(shape)
324335
325336
# Get geopandas object
326337
collection.to_geopandas()
327338
> geometry
328339
0 POLYGON ((0 0, 0 -1, 1 0, 0 0))
329340
330-
collection.to_geopandas("well", "name")
331-
> well name geometry
332-
0 A1 Shape_1 POLYGON ((0 0, 0 -1, 1 0, 0 0))
341+
collection.to_geopandas("well", "name", "metadata1", "metadata2")
342+
> well name metadata1 metadata2 geometry
343+
0 A1 Shape_1 A B POLYGON ((0 0, 0 -1, 1 0, 0 0))
333344
"""
334345
metadata = (
335346
pd.DataFrame(
336347
[
337-
[shape.__getattribute__(att) for att in attrs]
348+
[shape.get_shape_annotation(att) for att in attrs]
338349
for shape in self.shapes
339350
],
340351
columns=attrs,
@@ -393,6 +404,7 @@ def load_geopandas(
393404
well_column: Optional[str] = None,
394405
calibration_points: Optional[np.ndarray] = None,
395406
global_coordinates: Optional[int] = None,
407+
custom_attribute_columns: str | list[str] | None = None,
396408
) -> None:
397409
"""Create collection from a geopandas dataframe
398410
@@ -402,6 +414,8 @@ def load_geopandas(
402414
well_column (str, optional): Column storing of well id as additional metadata
403415
calibration_points (np.ndarray, optional): Calibration points of collection
404416
global_coordinates (int, optional): Number of global coordinates
417+
custom_attribute_columns Custom shape metadata that will be added as additional xml-element to the shape.
418+
Can be column name, list of column names or None
405419
406420
Example:
407421
@@ -433,11 +447,17 @@ def load_geopandas(
433447
if global_coordinates is not None:
434448
self.global_coordinates = global_coordinates
435449

450+
if custom_attribute_columns is None:
451+
custom_attribute_columns = []
452+
if isinstance(custom_attribute_columns, str):
453+
custom_attribute_columns = [custom_attribute_columns]
454+
436455
self.shapes = [
437456
Shape(
438457
points=np.array(row[geometry_column].exterior.coords),
439458
name=row[name_column] if name_column is not None else None,
440459
well=row[well_column] if well_column is not None else None,
460+
**{att: row[att] for att in custom_attribute_columns}
441461
)
442462
for _, row in gdf.iterrows()
443463
]
@@ -542,6 +562,7 @@ def __init__(
542562
well: Optional[str] = None,
543563
name: Optional[str] = None,
544564
orientation_transform=None,
565+
**custom_attributes: dict[str, str]
545566
):
546567
"""Class for creating a single shape.
547568
@@ -551,8 +572,10 @@ def __init__(
551572
well: Well in which to sort the shape after cutting. For example A1, A2 or B3.
552573
553574
name: Name of the shape.
554-
"""
555575
576+
custom_attributes: Custom shape metadata that will be added as additional xml-element to the shape
577+
Values be implicitly converted to strings.
578+
"""
556579
# Orientation transform of shapes
557580
self.orientation_transform: Optional[np.ndarray] = orientation_transform
558581

@@ -569,6 +592,8 @@ def __init__(
569592
self.name: Optional[str] = name
570593
self.well: Optional[str] = well
571594

595+
self.custom_attributes = custom_attributes
596+
572597
def from_xml(self, root):
573598
"""Load a shape from an XML shape node. Used internally for reading LMD generated XML files.
574599
@@ -599,10 +624,14 @@ def from_xml(self, root):
599624
points[point_id, 1] = int(child.text)
600625
elif child.tag == "CapID":
601626
self.well = str(child.text)
627+
else:
628+
if child.tag in self.custom_attributes:
629+
warnings.warn(f"Shape attribute {child.tag} already found in shape, overwrite", stacklevel=1)
630+
self.custom_attributes[child.tag] = child.text
602631

603632
self.points = np.array(points)
604633

605-
def to_xml(self, id: int, orientation_transform: np.ndarray, scale: int):
634+
def to_xml(self, id: int, orientation_transform: np.ndarray, scale: int, *, write_custom_attributes: bool = True):
606635
"""Generate XML shape node needed internally for export.
607636
608637
Args:
@@ -612,6 +641,8 @@ def to_xml(self, id: int, orientation_transform: np.ndarray, scale: int):
612641
613642
scale (int): Scalling factor used to enable higher decimal precision.
614643
644+
write_custom_attributes: Write custom attributes to xml file
645+
615646
Note:
616647
If the Shape has a custom orientation_transform defined, the custom orientation_transform is applied at this point. If not, the oritenation_transform passed by the parent Collection is used. This highlights an important difference between the Shape and Collection class. The Collection will always has an orientation transform defined and will use `np.eye(2)` by default. The Shape object can have a orientation_transform but can also be set to `None` to use the Collection value.
617648
@@ -633,6 +664,12 @@ def to_xml(self, id: int, orientation_transform: np.ndarray, scale: int):
633664
cap_id = ET.SubElement(shape, "CapID")
634665
cap_id.text = self.well
635666

667+
if write_custom_attributes:
668+
for attribute_name, attribute_value in self.custom_attributes.items():
669+
custom_attribute = ET.SubElement(shape, attribute_name)
670+
# xml only accepts string values
671+
custom_attribute.text = str(attribute_value)
672+
636673
# write points
637674
for i, point in enumerate(transformed_points):
638675
id = i + 1
@@ -643,6 +680,28 @@ def to_xml(self, id: int, orientation_transform: np.ndarray, scale: int):
643680
y.text = "{}".format(np.floor(point[1]).astype(int))
644681

645682
return shape
683+
684+
def get_shape_annotation(self, name: str) -> Any | None:
685+
"""Retrieve the value of an attribute from either instance attributes
686+
or custom attributes by name.
687+
688+
Searches for the attribute by name in the 1) instance attributes
689+
2) custom attributes, or 3) issues a warning and returns None
690+
691+
Args:
692+
name (str): The name of the attribute to retrieve.
693+
694+
Returns:
695+
Any | None: The value of the attribute if found, otherwise None.
696+
"""
697+
if name in self.__dict__:
698+
return getattr(self, name)
699+
elif name in self.custom_attributes:
700+
return self.custom_attributes.get(name)
701+
else:
702+
warnings.warn(f"Attribute {name} not found in shape attributes. Returning None.")
703+
return None
704+
646705

647706
def to_shapely(self):
648707
return shapely.Polygon(self.points)

src/lmd/lmd_test.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,29 @@
88
import geopandas as gpd
99
import shapely
1010
from lxml import etree as ET
11+
import pytest
1112

1213
def test_collection():
1314
calibration = np.array([[0, 0], [0, 100], [50, 50]])
1415
my_first_collection = Collection(calibration_points = calibration)
1516

17+
1618
def test_shape():
1719
rectangle_coordinates = np.array([[10,10], [40,10], [40,40], [10,40], [10,10]])
1820
rectangle = Shape(rectangle_coordinates)
1921

22+
2023
def test_shape_from_xml():
24+
"""Read a minimal xml representation of a cell shape and associated metadata"""
2125
# Define shape in xml
2226
shape_xml = """
2327
<Shape_1>
2428
<PointCount>3</PointCount>
2529
<CapID>A1</CapID>
2630
<TEST>this is a test</TEST>
31+
<test2>1</test2>
32+
<test3>3.1415</test3>
33+
2734
<X_1>0</X_1>
2835
<Y_1>-0</Y_1>
2936
<X_2>0</X_2>
@@ -41,6 +48,10 @@ def test_shape_from_xml():
4148
shape.from_xml(shape_xml)
4249
assert (shape.points == np.array([[ 0, 0], [ 0, -1], [ 1, 0]])).all()
4350
assert shape.well == "A1"
51+
assert shape.custom_attributes["TEST"] == "this is a test"
52+
assert shape.custom_attributes["test2"] == "1"
53+
assert shape.custom_attributes["test3"] == "3.1415"
54+
4455

4556
def test_plotting():
4657
calibration = np.array([[0, 0], [0, 100], [50, 50]])
@@ -60,31 +71,68 @@ def test_plotting():
6071

6172
my_first_collection.plot(calibration = True)
6273

63-
def test_collection_load_geopandas():
64-
gdf = gpd.GeoDataFrame(
65-
data={"well": ["A1"], "name": "my_shape"},
74+
75+
@pytest.fixture
76+
def geopandas_collection():
77+
"""Geopandas shape collection with both controlled (name, well) and custom metadata"""
78+
return gpd.GeoDataFrame(
79+
data={"well": ["A1"], "name": "my_shape", "string_attribute": "a"},
6680
geometry=[shapely.Polygon([[0, 0], [0, 1], [1, 0], [0, 0]])]
6781
)
6882

83+
@pytest.mark.parametrize(
84+
("well_column", "name_column", "custom_attributes"),
85+
[
86+
("well", None, None),
87+
(None, "well", None),
88+
(None, None, "string_attribute"),
89+
("well", "name", None),
90+
("well", "name", "string_attribute"),
91+
92+
]
93+
)
94+
def test_collection_load_geopandas(
95+
geopandas_collection: gpd.GeoDataFrame,
96+
well_column: str,
97+
name_column: str,
98+
custom_attributes: list[str]
99+
) -> None:
69100

70101
# Export well metadata
71102
c = Collection(calibration_points=np.array([[-1, -1], [1, 1], [0, 1]]))
72103
calibration_points_old = c.calibration_points
73-
c.load_geopandas(gdf, well_column="well", name_column="name")
74-
assert c.to_geopandas("well", "name").equals(gdf)
104+
c.load_geopandas(
105+
geopandas_collection,
106+
well_column=well_column,
107+
name_column=name_column,
108+
custom_attribute_columns=custom_attributes
109+
)
110+
111+
all_columns = [
112+
col for col in (well_column, custom_attributes) if col is not None
113+
]
114+
115+
assert c.to_geopandas(*all_columns).equals(geopandas_collection[[*all_columns, "geometry"]])
75116
assert (c.calibration_points == calibration_points_old).all()
76117

77118
# Overwrite calibration points
78119
c = Collection(calibration_points=np.array([[-1, -1], [1, 1], [0, 1]]))
79120
calibration_points_new = np.array([[0, 0], [100, 0], [0, 100]])
80-
c.load_geopandas(gdf, well_column="well", name_column="name", calibration_points=calibration_points_new)
81-
assert c.to_geopandas("well", "name").equals(gdf)
121+
122+
c.load_geopandas(
123+
geopandas_collection,
124+
calibration_points=calibration_points_new,
125+
well_column=well_column,
126+
name_column=name_column,
127+
custom_attribute_columns=custom_attributes
128+
)
129+
assert c.to_geopandas(*all_columns).equals(geopandas_collection[[*all_columns, "geometry"]])
82130
assert (c.calibration_points == calibration_points_new).all()
83131

84132
# Do not export well metadata
85133
c = Collection(calibration_points=np.array([[-1, -1], [1, 1], [0, 1]]))
86-
c.load_geopandas(gdf)
87-
assert c.to_geopandas().equals(gdf.drop(columns=["well", "name"]))
134+
c.load_geopandas(geopandas_collection)
135+
assert c.to_geopandas().equals(geopandas_collection[["geometry"]])
88136

89137
def test_collection_save():
90138
calibration = np.array([[0, 0], [0, 100], [50, 50]])
@@ -98,7 +146,8 @@ def test_collection_save():
98146
my_first_collection.add_shape(rectangle)
99147

100148
my_first_collection.save("first_collection.xml")
101-
149+
150+
102151
def test_tools_square():
103152
calibration = np.array([[0, 0], [0, 100], [50, 50]])
104153
my_first_collection = Collection(calibration_points = calibration)
@@ -108,7 +157,8 @@ def test_tools_square():
108157

109158
my_square = tools.rectangle(10, 10, offset=(30,30))
110159
my_first_collection.add_shape(my_square)
111-
160+
161+
112162
def test_glyphs():
113163
calibration = np.array([[0, 0], [0, 100], [50, 50]])
114164
my_first_collection = Collection(calibration_points = calibration)
@@ -135,6 +185,7 @@ def test_text():
135185
identifier_3 = tools.text('0123456789-_ABCDEFGHI', offset=np.array([60, 40]), rotation = -np.pi/4)
136186
my_first_collection.join(identifier_3)
137187

188+
138189
def test_segmentation_loader():
139190

140191
package_base_path = pathlib.Path(__file__).parent.parent.parent.resolve().absolute()

0 commit comments

Comments
 (0)