-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy path_plot.py
More file actions
48 lines (38 loc) · 1.55 KB
/
_plot.py
File metadata and controls
48 lines (38 loc) · 1.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import pandas as pd
def add_jitter(data, x_col, is_datetime=None, jitter_value=None):
    """
    Adds jitter to duplicate x-values for better visibility.

    Rows that share the same x-value are spread evenly over
    ``[x - jitter_value, x + jitter_value]`` in their original row order.
    Rows whose x-value is unique (or missing) keep their original position.

    Args:
        data (DataFrame): The subset of the dataset to jitter. Not modified;
            a copy is returned.
        x_col (str): Column name for x values.
        is_datetime (bool): Whether the x-values are datetime objects.
            If None, will be detected from the column dtype.
        jitter_value (float or timedelta): Maximum offset applied on either
            side of a duplicated x-value. For datetime x-values a number is
            interpreted as seconds; a timedelta-like value is also accepted.
            If None, no jitter is applied and 'jittered_x' mirrors the
            original x-values.

    Returns:
        DataFrame with an additional 'jittered_x' column. A frame that has
        no ``x_col`` column and no data is returned unchanged.
    """
    # Nothing to derive 'jittered_x' from (e.g. a bare ``pd.DataFrame()``).
    if data.empty and x_col not in data.columns:
        return data

    data = data.copy()
    x_values = data[x_col]

    # Auto-detect datetime if not specified
    if is_datetime is None:
        is_datetime = pd.api.types.is_datetime64_any_dtype(x_values)

    # Start from the original positions; numeric axes are widened to float
    # so that fractional offsets can be stored.
    data["jittered_x"] = x_values if is_datetime else x_values.astype(float)

    # No jitter requested: the original positions are the answer.
    if jitter_value is None:
        return data

    # Datetime offsets are generated in seconds; normalise timedelta-like
    # inputs to that unit while still accepting plain numbers.
    numeric_types = (int, float, np.integer, np.floating)
    if is_datetime and not isinstance(jitter_value, numeric_types):
        half_width = pd.Timedelta(jitter_value).total_seconds()
    else:
        half_width = jitter_value

    # Only values that occur more than once need any work.
    repeated_values = x_values[x_values.duplicated(keep=False)].unique()

    for x_val in repeated_values:
        # Positional (NumPy) mask: immune to duplicate index labels.
        mask = (x_values == x_val).to_numpy()
        count = int(mask.sum())
        if count < 2:
            # NaN / NaT are reported as duplicates but never compare equal.
            continue

        # Evenly spaced offsets, assigned to the group in row order.
        offsets = np.linspace(-half_width, half_width, count)
        if is_datetime:
            data.loc[mask, "jittered_x"] = [
                x_val + pd.Timedelta(seconds=float(seconds))
                for seconds in offsets
            ]
        else:
            data.loc[mask, "jittered_x"] = x_val + offsets

    return data