Skip to content

Commit eb45249

Browse files
committed
Move the graph data extraction to separate package
1 parent f47d9e6 commit eb45249

27 files changed

+1938
-470
lines changed

template/packages/data/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Extracting data for Code Interpreter SDK
2+
3+
This package extracts structured data from objects such as DataFrames and matplotlib plots so that the Code Interpreter SDK can consume it.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .main import graph_figure_to_graph
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .base import GraphType, Graph
2+
from .bars import BarGraph, BoxAndWhiskerGraph
3+
from .pie import PieGraph
4+
from .planar import ScatterGraph, LineGraph
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from typing import Literal, List
2+
3+
from matplotlib.axes import Axes
4+
from pydantic import BaseModel, Field
5+
6+
from .base import Graph2D, GraphType
7+
8+
9+
class BarData(BaseModel):
    """One bar of a bar chart, as collected by ``BarGraph``."""

    # Tick-label text paired with the bar.
    label: str
    # Label of the container (series) the bar was drawn in.
    group: str
    # Extent of the bar along its value axis.
    value: float
13+
14+
15+
class BarGraph(Graph2D):
    """Bar chart data extracted from a matplotlib ``Axes``.

    Each container on the axes is treated as one group of bars and every bar
    is stored as a ``BarData`` element. For horizontal bars the axis
    labels/units are swapped (once) so that ``x`` always describes the
    category axis and ``y`` the value axis.
    """

    type: Literal[GraphType.BAR] = GraphType.BAR

    elements: List[BarData] = Field(default_factory=list)

    @staticmethod
    def _is_horizontal(container, heights) -> bool:
        """Tell whether the bars in *container* extend along the x axis."""
        # Prefer the orientation recorded on the container when it has one.
        orientation = getattr(container, "orientation", None)
        if orientation in ("horizontal", "vertical"):
            return orientation == "horizontal"
        # Fallback heuristic: horizontal bars share the same thickness
        # (height) while their widths carry the values.
        return all(height == heights[0] for height in heights)

    def _extract_info(self, ax: Axes) -> None:
        """Populate ``elements`` (and axis metadata) from *ax*.

        :param ax: axes whose containers hold the bars to extract.
        """
        super()._extract_info(ax)

        # Swap the axis labels at most once; swapping per container would
        # undo itself for every second group of a grouped horizontal chart.
        orientation_swapped = False

        for container in ax.containers:
            group_label = container.get_label()
            if group_label.startswith("_container"):
                # Auto-generated label: turn the numeric suffix into a
                # readable group name; leave anything else untouched.
                suffix = group_label[len("_container"):]
                if suffix.isdigit():
                    group_label = f"Group {int(suffix)}"

            heights = [rect.get_height() for rect in container]
            if self._is_horizontal(container, heights):
                if not orientation_swapped:
                    self._change_orientation()
                    orientation_swapped = True
                labels = [tick.get_text() for tick in ax.get_yticklabels()]
                values = [rect.get_width() for rect in container]
            else:
                labels = [tick.get_text() for tick in ax.get_xticklabels()]
                values = heights

            self.elements.extend(
                BarData(label=label, value=value, group=group_label)
                for label, value in zip(labels, values)
            )
42+
43+
44+
class BoxAndWhiskerData(BaseModel):
    """Summary values of one box, as collected by ``BoxAndWhiskerGraph``."""

    # Label of the patch the box was read from.
    label: str
    # Far end of the lower whisker.
    min: float
    # Lower edge of the box.
    first_quartile: float
    # Position of the median line inside the box.
    median: float
    # Upper edge of the box.
    third_quartile: float
    # Far end of the upper whisker.
    max: float
51+
52+
53+
class BoxAndWhiskerGraph(Graph2D):
    """Box-and-whisker data extracted from a matplotlib ``Axes``.

    Boxes are read from ``ax.patches`` and their median / whisker values from
    the two-point lines in ``ax.lines`` that fall inside a box's extent along
    the category axis.
    """

    type: Literal[GraphType.BOX_AND_WHISKER] = GraphType.BOX_AND_WHISKER

    elements: List[BoxAndWhiskerData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """Populate ``elements`` (and axis metadata) from *ax*.

        :param ax: axes holding the box patches and their lines.
        :raises KeyError: if no median or whisker line could be matched to
            one of the boxes.
        """
        super()._extract_info(ax)

        boxes = []
        for patch in ax.patches:
            vertices = patch.get_path().vertices
            x_vertices = vertices[:, 0]
            y_vertices = vertices[:, 1]
            x_min, x_max = min(x_vertices), max(x_vertices)
            y_min, y_max = min(y_vertices), max(y_vertices)
            boxes.append(
                {
                    "label": patch.get_label(),
                    "x": x_min,
                    "y": y_min,
                    # Exact far edges: line endpoints are compared against
                    # these, never against a value rebuilt from the rounded
                    # size below.
                    "x_max": x_max,
                    "y_max": y_max,
                    # Rounded sizes are only used to compare boxes with each
                    # other when guessing the orientation.
                    "width": round(x_max - x_min, 4),
                    "height": round(y_max - y_min, 4),
                }
            )

        # Boxes that all share the same extent along y are taken to be laid
        # out along the y axis; their coordinates are swapped so the rest of
        # the method can always treat x as the category axis. Without any
        # box there is nothing to infer, so the labels are left alone.
        swapped = bool(boxes) and all(
            box["height"] == boxes[0]["height"] for box in boxes
        )
        if swapped:
            self._change_orientation()
            for box in boxes:
                box["x"], box["y"] = box["y"], box["x"]
                box["x_max"], box["y_max"] = box["y_max"], box["x_max"]
                box["width"], box["height"] = box["height"], box["width"]

        for line in ax.lines:
            xdata = line.get_xdata()
            ydata = line.get_ydata()
            if swapped:
                xdata, ydata = ydata, xdata

            # Only two-point segments can be a median or a whisker.
            if len(ydata) != 2:
                continue

            # First box whose category-axis extent contains the segment.
            box = next(
                (
                    candidate
                    for candidate in boxes
                    if candidate["x"] <= xdata[0] <= xdata[1] <= candidate["x_max"]
                ),
                None,
            )
            if box is None:
                continue

            # A flat segment inside the box is the median line.
            if ydata[0] == ydata[1] and box["y"] <= ydata[0] <= box["y_max"]:
                box["median"] = ydata[0]
                continue

            lower_value = min(ydata)
            upper_value = max(ydata)
            if upper_value == box["y"]:
                # Segment ends on the lower box edge: lower whisker.
                box["whisker_lower"] = lower_value
            elif lower_value == box["y_max"]:
                # Segment starts on the upper box edge: upper whisker.
                box["whisker_upper"] = upper_value

        self.elements = [
            BoxAndWhiskerData(
                label=box["label"],
                min=box["whisker_lower"],
                first_quartile=box["y"],
                median=box["median"],
                third_quartile=box["y_max"],
                max=box["whisker_upper"],
            )
            for box in boxes
        ]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import enum
2+
import re
3+
from typing import Optional, List, Any
4+
5+
from matplotlib.axes import Axes
6+
from pydantic import BaseModel, Field
7+
8+
9+
@enum.unique
class GraphType(str, enum.Enum):
    """Kinds of graph the extraction models can describe.

    Members are also ``str`` instances, so they compare equal to their plain
    string value. ``enum.unique`` makes a duplicated value an import-time
    error instead of a silent alias.
    """

    LINE = "line"
    SCATTER = "scatter"
    BAR = "bar"
    PIE = "pie"
    BOX_AND_WHISKER = "box_and_whisker"
    SUPERGRAPH = "supergraph"
    UNKNOWN = "unknown"
17+
18+
19+
class Graph(BaseModel):
    """Base model for data extracted from a matplotlib ``Axes``.

    Subclasses override ``_extract_info`` to fill in their own fields; it is
    invoked from ``__init__`` whenever an axes object is supplied.
    """

    type: GraphType
    title: Optional[str] = None

    elements: List[Any] = Field(default_factory=list)

    def __init__(self, ax: Optional[Axes] = None, **kwargs):
        """Create the model and, if *ax* is given, extract its data.

        :param ax: axes to read the graph from; ``None`` skips extraction.
        :param kwargs: regular model field values.
        """
        super().__init__(**kwargs)
        # Explicit None check: extraction must not depend on the
        # truthiness of the axes object.
        if ax is not None:
            self._extract_info(ax)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for Graph
        """
        # An empty title is stored as None rather than "".
        self.title = ax.get_title() or None
39+
40+
41+
class Graph2D(Graph):
    """Graph with an x and a y axis, each with an optional label and unit."""

    x_label: Optional[str] = None
    y_label: Optional[str] = None
    x_unit: Optional[str] = None
    y_unit: Optional[str] = None

    @staticmethod
    def _search_unit(axis_label: Optional[str]):
        """Search *axis_label* for a unit written as `` (unit)`` or ``[unit]``.

        Returns the regex match, or ``None`` when the label is missing/empty
        or contains no such part.
        """
        if not axis_label:
            return None
        return re.search(r"\s\((.*?)\)|\[(.*?)\]", axis_label)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for Graph2D
        """
        super()._extract_info(ax)

        # Empty labels are stored as None rather than "".
        self.x_label = ax.get_xlabel() or None
        self.y_label = ax.get_ylabel() or None

        # Units are only overwritten when the label actually carries one.
        x_match = self._search_unit(self.x_label)
        if x_match is not None:
            self.x_unit = x_match.group(1) or x_match.group(2)

        y_match = self._search_unit(self.y_label)
        if y_match is not None:
            self.y_unit = y_match.group(1) or y_match.group(2)

    def _change_orientation(self):
        """Exchange the x and y labels and units with each other."""
        previous_x = (self.x_label, self.x_unit)
        self.x_label, self.x_unit = self.y_label, self.y_unit
        self.y_label, self.y_unit = previous_x
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Literal, List
2+
3+
from matplotlib.axes import Axes
4+
from pydantic import BaseModel, Field
5+
6+
from .base import Graph, GraphType
7+
8+
9+
class PieData(BaseModel):
    """One wedge of a pie chart, as collected by ``PieGraph``."""

    # Label of the wedge patch.
    label: str
    # Absolute difference between the wedge's theta2 and theta1.
    angle: float
    # Radius of the wedge.
    radius: float
13+
14+
15+
class PieGraph(Graph):
    """Pie chart data extracted from the patches of a matplotlib ``Axes``."""

    type: Literal[GraphType.PIE] = GraphType.PIE

    elements: List[PieData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """Append one ``PieData`` element per patch found on *ax*."""
        super()._extract_info(ax)

        self.elements.extend(
            PieData(
                label=wedge.get_label(),
                # Angular span of the wedge, rounded to 4 decimals.
                angle=abs(round(wedge.theta2 - wedge.theta1, 4)),
                radius=wedge.r,
            )
            for wedge in ax.patches
        )
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from datetime import date
2+
from typing import List, Tuple, Union, Sequence, Any, Literal
3+
4+
import matplotlib
5+
import numpy
6+
from matplotlib.axes import Axes
7+
from matplotlib.dates import _SwitchableDateConverter
8+
from pydantic import BaseModel, field_validator, Field
9+
10+
from .base import Graph2D, GraphType
11+
from ..utils.filtering import is_grid_line
12+
13+
14+
class PointData(BaseModel):
    """A labelled series of ``(x, y)`` points."""

    label: str
    points: List[Tuple[Union[str, float], Union[str, float]]]

    @field_validator("points", mode="before")
    @classmethod
    def transform_points(
        cls, value
    ) -> List[Tuple[Union[str, float], Union[str, float]]]:
        """Turn date-like coordinates into strings before validation.

        ``date`` objects become their ISO format; ``numpy.datetime64`` values
        are converted to second precision and then to ``str``. Every other
        coordinate is passed through unchanged.
        """

        def as_plain(coordinate):
            # Convert a single coordinate; non-date values are returned as is.
            if isinstance(coordinate, date):
                return coordinate.isoformat()
            if isinstance(coordinate, numpy.datetime64):
                return coordinate.astype("datetime64[s]").astype(str)
            return coordinate

        return [(as_plain(x), as_plain(y)) for x, y in value]
37+
38+
39+
class PointGraph(Graph2D):
    """Common base for graphs made of points (lines, scatter plots).

    On top of the ``Graph2D`` metadata it records, for each axis, the tick
    positions, the tick label texts and the detected scale.
    """

    x_ticks: List[Union[str, float]] = Field(default_factory=list)
    x_tick_labels: List[str] = Field(default_factory=list)
    x_scale: str = Field(default="linear")

    y_ticks: List[Union[str, float]] = Field(default_factory=list)
    y_tick_labels: List[str] = Field(default_factory=list)
    y_scale: str = Field(default="linear")

    elements: List[PointData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for PointGraph
        """
        super()._extract_info(ax)

        self.x_tick_labels = [label.get_text() for label in ax.get_xticklabels()]
        self.x_ticks = self._extract_ticks_info(ax.xaxis.converter, ax.get_xticks())
        self.x_scale = self._detect_scale(
            ax.xaxis.converter, ax.get_xscale(), self.x_ticks, self.x_tick_labels
        )

        self.y_tick_labels = [label.get_text() for label in ax.get_yticklabels()]
        self.y_ticks = self._extract_ticks_info(ax.yaxis.converter, ax.get_yticks())
        self.y_scale = self._detect_scale(
            ax.yaxis.converter, ax.get_yscale(), self.y_ticks, self.y_tick_labels
        )

    @staticmethod
    def _detect_scale(converter, scale: str, ticks: Sequence, labels: Sequence) -> str:
        """Classify an axis as "datetime", "categorical" or its plain scale.

        :param converter: unit converter attached to the axis.
        :param scale: scale name reported by the axes (e.g. "linear").
        :param ticks: extracted tick positions.
        :param labels: tick label texts, aligned with *ticks*.
        :return: "datetime", "categorical", or *scale*.
        """
        # If the converter is a date converter, it's a datetime scale
        if isinstance(converter, _SwitchableDateConverter):
            return "datetime"

        # If the scale is not linear, it can't be categorical
        if scale != "linear":
            return scale

        # Without ticks or labels nothing hints at categories.
        if len(ticks) == 0 or len(labels) == 0:
            return "linear"

        # If all the ticks are integers and are in order from 0 to n-1
        # and the labels aren't corresponding to the ticks, it's categorical
        for i, (tick, label) in enumerate(zip(ticks, labels)):
            looks_categorical = (
                isinstance(tick, (int, float)) and tick == i and str(i) != label
            )
            if not looks_categorical:
                # Found a tick, which wouldn't be in a categorical scale
                return "linear"

        return "categorical"

    @staticmethod
    def _extract_ticks_info(converter: Any, ticks: Sequence) -> list:
        """Convert raw tick positions into serialisable values.

        :param converter: unit converter attached to the axis.
        :param ticks: tick positions as returned by the axes.
        :return: ISO strings for date axes, floats for numeric ticks,
            otherwise the ticks as a plain list (empty if there are none).
        """
        if isinstance(converter, _SwitchableDateConverter):
            return [matplotlib.dates.num2date(tick).isoformat() for tick in ticks]

        # An axis may have no ticks at all; ``len`` is used because the
        # ticks may be an array, whose truthiness is ambiguous.
        if len(ticks) == 0:
            return []

        # The first tick decides how the whole sequence is treated.
        if isinstance(ticks[0], (int, float)):
            return [float(tick) for tick in ticks]
        return list(ticks)
102+
103+
104+
class LineGraph(PointGraph):
    """Line chart data: one ``PointData`` element per non-grid line."""

    type: Literal[GraphType.LINE] = GraphType.LINE

    def _extract_info(self, ax: Axes) -> None:
        """Collect the points of every line on *ax*, skipping grid lines."""
        super()._extract_info(ax)

        for line in ax.get_lines():
            if is_grid_line(line):
                continue

            name = line.get_label()
            if name.startswith("_child"):
                # Label of the form "_child<N>": expose it as "Line <N>".
                name = f"Line {int(name[6:])}"

            coordinates = list(zip(line.get_xdata(), line.get_ydata()))
            self.elements.append(PointData(label=name, points=coordinates))
121+
122+
123+
class ScatterGraph(PointGraph):
    """Scatter plot data: one ``PointData`` element per collection."""

    type: Literal[GraphType.SCATTER] = GraphType.SCATTER

    def _extract_info(self, ax: Axes) -> None:
        """Collect the offsets of every collection on *ax* as points."""
        super()._extract_info(ax)

        for collection in ax.collections:
            offsets = collection.get_offsets()
            coordinates = [(px, py) for px, py in offsets]
            self.elements.append(
                PointData(label=collection.get_label(), points=coordinates)
            )

0 commit comments

Comments
 (0)