Skip to content

Commit b340a60

Browse files
committed
🛠️ [WIP] First attempts at Graphs
xtl.common.labels:LabelFmt - New enum for determining how a `Label` will be formatted xtl.common.data - New type `TData` for defining any `Data*` container - Fixed a type annotation issue in `DataGrid2D.z` xtl.plots.base - New `GraphConfig` model for storing certain global options - New `GraphBase` model that should be inherited - New `Graph1D` model with simple implementation of a 1D plot
1 parent 0a6cc14 commit b340a60

File tree

5 files changed

+86
-4
lines changed

5 files changed

+86
-4
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ src/xtl/scripts
55
src/xtl/chemistry
66
src/xtl/chemistry2
77
src/xtl/crystallization2
8-
src/xtl/plots
98

109
### Temporary
1110
docs_/

src/xtl/common/data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Union
22
from typing_extensions import Self
33

44
from numpydantic import NDArray, Shape
@@ -8,6 +8,9 @@
88
from xtl.common.labels import Label
99

1010

11+
TData = Union['Data0D', 'Data1D', 'Data2D', 'Data3D', 'DataGrid2D']
12+
13+
1114
class Data0D(BaseModel):
1215
"""
1316
A class to represent 0D data, i.e. one series of values.
@@ -86,7 +89,7 @@ class DataGrid2D(BaseModel):
8689

8790
x: Data0D
8891
y: Data0D
89-
z: NDArray[Shape['*', '*'], Number]
92+
z: NDArray[Shape['*, *'], Number]
9093

9194
@model_validator(mode='after')
9295
def check_array_shapes(self) -> Self:

src/xtl/common/labels.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
from enum import Enum
12
from typing import Optional
23

34
from pydantic import BaseModel, ConfigDict, Field
45

56

7+
class LabelFmt(Enum):
8+
VALUE = 'value'
9+
DESC = 'desc'
10+
REPR = 'repr'
11+
LATEX = 'latex'
12+
13+
614
class Label(BaseModel):
715
__pydantic_config__ = ConfigDict(validate_assignment=True, extra='forbid')
816

917
value: str
1018
desc: Optional[str] = Field(default=None, repr=False)
11-
repr: Optional[str] = Field(default=None, repr=False)
19+
repr: Optional[str] = Field(default=None, repr=False) # rename: ascii
1220
latex: Optional[str] = Field(default=None, repr=False)

src/xtl/plots/__init__.py

Whitespace-only changes.

src/xtl/plots/base/base.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any, Optional
2+
3+
from matplotlib.axes import Axes
4+
from matplotlib.figure import Figure
5+
import matplotlib.pyplot as plt
6+
from pydantic import BaseModel, ConfigDict, Field, validate_call
7+
8+
from xtl.common.data import TData, Data0D, Data1D
9+
from xtl.common.labels import Label, LabelFmt
10+
11+
12+
class GraphConfig(BaseModel):
13+
model_config = ConfigDict(validate_assignment=True, extra='forbid')
14+
15+
label_fmt: LabelFmt = LabelFmt.VALUE
16+
17+
18+
class GraphBase(BaseModel):
19+
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True,
20+
extra='forbid')
21+
22+
data: TData
23+
config: GraphConfig = Field(default_factory=GraphConfig)
24+
25+
# Fix: Should axis & figure be excluded from the model serialization?
26+
axis: Optional[Axes] = Field(default_factory=lambda: plt.gca(), alias='ax')
27+
figure: Optional[Figure] = Field(default_factory=lambda: plt.gcf(), alias='fig')
28+
29+
title: Optional[str] = None
30+
31+
def get_label(self, data: Data0D) -> Optional[str]:
32+
value = getattr(data.label, self.config.label_repr.value)
33+
units = getattr(data.units, self.config.label_repr.value)
34+
35+
v = value or data.label.value
36+
u = units or data.units.value
37+
if self.config.label_repr in [LabelFmt.VALUE, LabelFmt.LATEX]:
38+
if u:
39+
return f'{v} ({u})'
40+
return v
41+
elif self.config.label_repr == LabelFmt.REPR:
42+
if u:
43+
return f'{v} [{u}]'
44+
return v
45+
return None
46+
47+
def plot(self):
48+
raise NotImplementedError
49+
50+
51+
class Graph1D(GraphBase):
52+
data: Data1D
53+
54+
@property
55+
def xlabel(self) -> str:
56+
return self.get_label(self.data.x)
57+
58+
@property
59+
def ylabel(self) -> str:
60+
return self.get_label(self.data.y)
61+
62+
@validate_call
63+
def plot(self, xlabel: Optional[str] = None, ylabel: Optional[str] = None):
64+
if xlabel:
65+
self.data.x.label = Label(value=xlabel)
66+
if ylabel:
67+
self.data.y.label = Label(value=ylabel)
68+
69+
self.axis.plot(self.data.x.data, self.data.y.data)
70+
71+
self.axis.set_xlabel(self.xlabel)
72+
self.axis.set_ylabel(self.ylabel)

0 commit comments

Comments
 (0)