Skip to content

Commit 21e892e

Browse files
committed
Add outliers, label and fix rounding issue
1 parent 1285d6a commit 21e892e

File tree

3 files changed

+45
-18
lines changed

3 files changed

+45
-18
lines changed

python/e2b_code_interpreter/charts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class BoxAndWhiskerData:
165165
median: float
166166
third_quartile: float
167167
max: float
168+
outliers: List[float]
168169

169170
def __init__(self, **kwargs):
170171
self.label = kwargs["label"]
@@ -173,6 +174,7 @@ def __init__(self, **kwargs):
173174
self.median = kwargs["median"]
174175
self.third_quartile = kwargs["third_quartile"]
175176
self.max = kwargs["max"]
177+
self.outliers = kwargs.get("outliers") or []
176178

177179

178180
class BoxAndWhiskerChart(Chart2D):

python/tests/graphs/test_box_and_whiskers.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
# Create figure and axis
1616
fig, ax = plt.subplots(figsize=(10, 6))
1717
18-
# Plot box plot
19-
ax.boxplot(data.values(), labels=data.keys())
20-
2118
# Customize plot
2219
ax.set_title('Exam Scores Distribution')
2320
ax.set_xlabel('Class')
@@ -54,9 +51,10 @@ async def test_box_and_whiskers(async_sandbox: AsyncSandbox):
5451
bars = chart.elements
5552
assert len(bars) == 3
5653

57-
assert all(isinstance(bar.min, float) for bar in bars)
58-
assert all(isinstance(bar.first_quartile, float) for bar in bars)
59-
assert all(isinstance(bar.median, float) for bar in bars)
60-
assert all(isinstance(bar.third_quartile, float) for bar in bars)
61-
assert all(isinstance(bar.max, float) for bar in bars)
62-
assert all(isinstance(bar.label, str) for bar in bars)
54+
assert [bar.outliers for bar in bars] == [[], [76], []]
55+
assert [bar.min for bar in bars] == [78, 84, 75]
56+
assert [bar.first_quartile for bar in bars] == [85, 84.75, 79]
57+
assert [bar.median for bar in bars] == [88, 88, 82]
58+
assert [bar.third_quartile for bar in bars] == [90, 90.5, 86]
59+
assert [bar.max for bar in bars] == [92, 95, 88]
60+
assert [bar.label for bar in bars] == ["Class A", "Class B", "Class C"]

template/startup_scripts/0002_data.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import date
22
import enum
33
import re
4+
from decimal import Decimal, localcontext
45
from typing import Optional, List, Tuple, Literal, Any, Union, Sequence
56

67
import matplotlib
@@ -35,6 +36,18 @@ def _is_grid_line(line: Line2D) -> bool:
3536
return False
3637

3738

39+
def _dynamic_round(number):
40+
# Convert to Decimal for precise control
41+
decimal_number = Decimal(str(number))
42+
43+
# Dynamically determine precision based on magnitude
44+
precision = max(1, 16 - decimal_number.adjusted()) # 16 digits of precision
45+
46+
with localcontext() as ctx:
47+
ctx.prec = precision # Set the dynamic precision
48+
return +decimal_number # The + operator applies rounding
49+
50+
3851
class ChartType(str, enum.Enum):
3952
LINE = "line"
4053
SCATTER = "scatter"
@@ -291,6 +304,7 @@ class BoxAndWhiskerData(BaseModel):
291304
median: float
292305
third_quartile: float
293306
max: float
307+
outliers: List[float] = Field(default_factory=list)
294308

295309

296310
class BoxAndWhiskerChart(Chart2D):
@@ -301,20 +315,23 @@ class BoxAndWhiskerChart(Chart2D):
301315
def _extract_info(self, ax: Axes) -> None:
302316
super()._extract_info(ax)
303317

318+
labels = [item.get_text() for item in ax.get_xticklabels()]
319+
304320
boxes = []
305-
for box in ax.patches:
321+
for label, box in zip(labels, ax.patches):
306322
vertices = box.get_path().vertices
307-
x_vertices = vertices[:, 0]
308-
y_vertices = vertices[:, 1]
323+
x_vertices = [_dynamic_round(x) for x in vertices[:, 0]]
324+
y_vertices = [_dynamic_round(y) for y in vertices[:, 1]]
309325
x = min(x_vertices)
310326
y = min(y_vertices)
311327
boxes.append(
312328
{
313329
"x": x,
314330
"y": y,
315-
"label": box.get_label(),
316-
"width": round(max(x_vertices) - x, 4),
317-
"height": round(max(y_vertices) - y, 4),
331+
"label": label,
332+
"width": max(x_vertices) - x,
333+
"height": max(y_vertices) - y,
334+
"outliers": [],
318335
}
319336
)
320337

@@ -328,13 +345,21 @@ def _extract_info(self, ax: Axes) -> None:
328345
box["x"], box["y"] = box["y"], box["x"]
329346
box["width"], box["height"] = box["height"], box["width"]
330347

331-
for line in ax.lines:
332-
xdata = line.get_xdata()
333-
ydata = line.get_ydata()
348+
for i, line in enumerate(ax.lines):
349+
xdata = [_dynamic_round(x) for x in line.get_xdata()]
350+
ydata = [_dynamic_round(y) for y in line.get_ydata()]
334351

335352
if orientation == "vertical":
336353
xdata, ydata = ydata, xdata
337354

355+
if len(xdata) == 1:
356+
for box in boxes:
357+
if box["x"] <= xdata[0] <= box["x"] + box["width"]:
358+
break
359+
else:
360+
continue
361+
362+
box["outliers"].append(ydata[0])
338363
if len(ydata) != 2:
339364
continue
340365
for box in boxes:
@@ -344,6 +369,7 @@ def _extract_info(self, ax: Axes) -> None:
344369
continue
345370

346371
if (
372+
# Check if the line is inside the box, prevent floating point errors
347373
ydata[0] == ydata[1]
348374
and box["y"] <= ydata[0] <= box["y"] + box["height"]
349375
):
@@ -365,6 +391,7 @@ def _extract_info(self, ax: Axes) -> None:
365391
median=box["median"],
366392
third_quartile=box["y"] + box["height"],
367393
max=box["whisker_upper"],
394+
outliers=box["outliers"],
368395
)
369396
for box in boxes
370397
]

0 commit comments

Comments
 (0)