Commit 566fcf9

Merge pull request #62 from atomscale-ai/enhancement/add_similarity_traj_provider
Similarity trajectory provider
2 parents: 548c16f + ef7da59

9 files changed: +276 additions, -25 deletions


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -125,7 +125,6 @@ lint.extend-ignore = [
   "B028", # No explicit stacklevel
   "EM101", # Exception must not use a string literal
   "EM102", # Exception must not use an f-string literal
-  "PD901", # Avoid using the generic variable name `df` for DataFrames
 ]
 lint.typing-modules = ["mypackage._compat.typing"]
 src = ["src"]

src/atomscale/results/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from .raman import RamanResult
 from .rheed_image import RHEEDImageCollection, RHEEDImageResult, _get_rheed_image_result
 from .rheed_video import RHEEDVideoResult
+from .similarity_trajectory import SimilarityTrajectoryResult
 from .unknown import UnknownResult
 from .xps import XPSResult

@@ -18,6 +19,7 @@
     "RHEEDImageResult",
     "RHEEDVideoResult",
     "RamanResult",
+    "SimilarityTrajectoryResult",
     "UnknownResult",
     "XPSResult",
     "_get_rheed_image_result",

src/atomscale/results/similarity_trajectory.py (new file)

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from uuid import UUID
+
+from monty.json import MSONable
+from pandas import DataFrame
+
+
+class SimilarityTrajectoryResult(MSONable):
+    def __init__(
+        self,
+        source_id: UUID | str,
+        workflow: str,
+        window_span: float,
+        timeseries_data: DataFrame,
+        source_data_ids: Sequence[UUID | str] | None = None,
+    ):
+        """Similarity trajectory result
+
+        Args:
+            source_id (UUID | str): Source ID for the similarity trajectory query.
+            workflow (str): Workflow name used for the similarity analysis.
+            window_span (float): Window span parameter used for the trajectory.
+            timeseries_data (DataFrame): Pandas DataFrame with similarity trajectory data.
+            source_data_ids (Sequence[UUID | str] | None): Sequence of source data IDs included in the trajectory.
+        """
+        self.source_id = source_id
+        self.workflow = workflow
+        self.window_span = window_span
+        self.timeseries_data = timeseries_data
+        self.source_data_ids: list[UUID | str] = (
+            list(source_data_ids) if source_data_ids else []
+        )

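For orientation, a minimal sketch of constructing the new result class directly. The constructor signature and the "rheed_stationary" workflow name come from this PR; the DataFrame contents, window span, and IDs below are illustrative placeholders.

from uuid import uuid4

from pandas import DataFrame

from atomscale.results import SimilarityTrajectoryResult

# Illustrative trajectory data; in practice this DataFrame is produced by
# SimilarityTrajectoryProvider.to_dataframe from the API payload.
timeseries = DataFrame(
    {
        "Similarity": [0.91, 0.88, 0.93],
        "UNIX Timestamp": [1700000000.0, 1700000001.0, 1700000002.0],
    }
)

result = SimilarityTrajectoryResult(
    source_id=uuid4(),            # placeholder source ID
    workflow="rheed_stationary",  # workflow name used in tests/conftest.py
    window_span=5.0,              # placeholder window span
    timeseries_data=timeseries,
    source_data_ids=None,         # stored as an empty list when omitted
)
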
src/atomscale/timeseries/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,11 +4,13 @@
 from .provider import TimeseriesProvider
 from .registry import get_provider
 from .rheed import RHEEDProvider
+from .similarity import SimilarityTrajectoryProvider

 __all__ = [
     "MetrologyProvider",
     "OpticalProvider",
     "RHEEDProvider",
+    "SimilarityTrajectoryProvider",
     "TimeseriesProvider",
     "align_timeseries",
     "get_provider",

src/atomscale/timeseries/align.py

Lines changed: 24 additions & 24 deletions
@@ -71,26 +71,26 @@ def _extract_timeseries(result):
     """Return (data_id, domain, df_with_timeindex) or None for non-timeseries."""
     if isinstance(result, RHEEDVideoResult):
         domain = "rheed"
-        df = result.timeseries_data
+        timeseries = result.timeseries_data
     elif isinstance(result, OpticalResult):
         domain = "optical"
-        df = result.timeseries_data
+        timeseries = result.timeseries_data
     elif isinstance(result, MetrologyResult):
         domain = "metrology"
-        df = result.timeseries_data
+        timeseries = result.timeseries_data
     else:
         return None

-    if df is None or df.empty:
+    if timeseries is None or timeseries.empty:
         return None

     # Build time index: prefer absolute epochs; fall back to upload_datetime + relative offsets.
     upload_dt = getattr(result, "upload_datetime", None)

-    time_index = _infer_absolute_time(df)
+    time_index = _infer_absolute_time(timeseries)
     if time_index is None and upload_dt is not None:
         base = pd.to_datetime(upload_dt, utc=True, errors="coerce")
-        rel = _infer_relative_time(df)
+        rel = _infer_relative_time(timeseries)
         if base is not pd.NaT and rel is not None:
             time_index = base + rel

@@ -101,7 +101,7 @@ def _extract_timeseries(result):
     if not valid_mask.any():
         return None

-    indexed = df.loc[valid_mask].copy(deep=False)
+    indexed = timeseries.loc[valid_mask].copy(deep=False)
     indexed.index = pd.Index(time_index[valid_mask], name="time")
     indexed = indexed.sort_index()

@@ -173,11 +173,11 @@ def align_timeseries(
         if not extracted:
             continue

-        data_id, domain, df = extracted
-        df = df.copy(deep=False)
-        df.columns = pd.MultiIndex.from_product([[data_id], [domain], df.columns])
-        frames.append(df)
-        indices.append(df.index)
+        data_id, domain, frame = extracted
+        frame = frame.copy(deep=False)
+        frame.columns = pd.MultiIndex.from_product([[data_id], [domain], frame.columns])
+        frames.append(frame)
+        indices.append(frame.index)

     if not frames:
         return pd.DataFrame()
@@ -211,28 +211,28 @@ def align_timeseries(

     # Merge compatible metrics across items: if multiple columns share (domain, metric)
     # and never conflict where they overlap, collapse into (shared, domain, metric).
-    def _merge_compatible_metrics(df: pd.DataFrame) -> pd.DataFrame:
-        if not isinstance(df.columns, pd.MultiIndex):
-            return df
-        domains = df.columns.get_level_values(1)
-        metrics = df.columns.get_level_values(2)
+    def _merge_compatible_metrics(data: pd.DataFrame) -> pd.DataFrame:
+        if not isinstance(data.columns, pd.MultiIndex):
+            return data
+        domains = data.columns.get_level_values(1)
+        metrics = data.columns.get_level_values(2)
         new_cols: dict = {}
         drop_cols: list = []

         for domain in domains.unique():
             for metric in metrics.unique():
                 cols = [
                     c
-                    for c in df.columns
+                    for c in data.columns
                     if c[1] == domain and c[2] == metric and c[0] != "shared"
                 ]
                 if len(cols) <= 1:
                     continue

-                merged = df[cols[0]]
+                merged = data[cols[0]]
                 conflict = False
                 for c in cols[1:]:
-                    other = df[c]
+                    other = data[c]
                     overlap_mask = merged.notna() & other.notna()
                     if (merged[overlap_mask] != other[overlap_mask]).any():
                         conflict = True
@@ -247,10 +247,10 @@ def _merge_compatible_metrics(df: pd.DataFrame) -> pd.DataFrame:
                 drop_cols.extend(cols)

         if new_cols:
-            df = df.drop(columns=drop_cols)
+            data = data.drop(columns=drop_cols)
             for col, series in new_cols.items():
-                df[col] = series
-            df = df.sort_index(axis=1)
-        return df
+                data[col] = series
+            data = data.sort_index(axis=1)
+        return data

     return _merge_compatible_metrics(aligned)

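To make the column-merging behavior concrete, here is a small, self-contained pandas sketch (not code from this PR) of the three-level (item, domain, metric) column layout that align_timeseries builds, and of what collapsing two non-conflicting columns into a shared one amounts to; the item names and the "intensity" metric are illustrative.

import pandas as pd

# Two items report the same (domain, metric) pair and never disagree where both
# have values, so the merge step can collapse them into one shared column.
time_index = pd.Index(
    pd.to_datetime(["2024-01-01 00:00:00", "2024-01-01 00:00:01"], utc=True),
    name="time",
)
aligned = pd.DataFrame(
    {
        ("item-a", "rheed", "intensity"): [1.0, None],
        ("item-b", "rheed", "intensity"): [None, 2.0],
    },
    index=time_index,
)

# Conceptually equivalent to the collapse into ("shared", "rheed", "intensity"):
merged = aligned[("item-a", "rheed", "intensity")].combine_first(
    aligned[("item-b", "rheed", "intensity")]
)
print(merged.tolist())  # [1.0, 2.0]
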
src/atomscale/timeseries/registry.py

Lines changed: 2 additions & 0 deletions
@@ -4,11 +4,13 @@
 from .optical import OpticalProvider
 from .provider import TimeseriesProvider
 from .rheed import RHEEDProvider
+from .similarity import SimilarityTrajectoryProvider

 _PROVIDER_CLASSES: dict[str, type[TimeseriesProvider]] = {
     RHEEDProvider.TYPE: RHEEDProvider,
     OpticalProvider.TYPE: OpticalProvider,
     MetrologyProvider.TYPE: MetrologyProvider,
+    SimilarityTrajectoryProvider.TYPE: SimilarityTrajectoryProvider,
 }


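As a quick sanity-check sketch of what this registration enables (get_provider's exact signature is not shown in this diff, so the lookup below goes through the private mapping directly):

from atomscale.timeseries.registry import _PROVIDER_CLASSES
from atomscale.timeseries.similarity import SimilarityTrajectoryProvider

# The new TYPE string now resolves to the provider class.
assert _PROVIDER_CLASSES["similarity_trajectory"] is SimilarityTrajectoryProvider
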
src/atomscale/timeseries/similarity.py (new file)

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+from collections.abc import Mapping, Sequence
+from typing import Any
+from uuid import UUID
+
+from pandas import DataFrame, concat
+
+from atomscale.core import BaseClient
+from atomscale.results.similarity_trajectory import SimilarityTrajectoryResult
+from atomscale.timeseries.provider import TimeseriesProvider
+
+
+class SimilarityTrajectoryProvider(TimeseriesProvider[SimilarityTrajectoryResult]):
+    TYPE = "similarity_trajectory"
+
+    RENAME_MAP: Mapping[str, str] = {
+        "reference_id": "Reference ID",
+        "reference_item_name": "Reference Name",
+        "real_time_seconds": "Time",
+        "similarity_values": "Similarity",
+        "unix_times": "UNIX Timestamp",
+        "is_active": "Active",
+        "averaged_count": "Averaged Count",
+    }
+    INDEX_COLS: Sequence[str] = ["Reference ID", "Time"]
+
+    def fetch_raw(self, client: BaseClient, data_id: str, **kwargs: Any) -> Any:
+        """Fetch similarity trajectory data from the API.
+
+        Args:
+            client: The API client.
+            data_id: The source ID for the similarity query.
+            **kwargs: Must include 'workflow' (required). Optional parameters:
+                window_span, reference_ids, softmax_mode, reference_n_values.
+
+        Returns:
+            Raw API response payload.
+
+        Raises:
+            KeyError: If 'workflow' is not provided in kwargs.
+        """
+        workflow = kwargs.pop("workflow")
+        return client._get(
+            sub_url=f"similarity/{workflow}/{data_id}/trajectory/",
+            params=kwargs,
+        )
+
+    def to_dataframe(self, raw: Any) -> DataFrame:
+        if not raw:
+            return DataFrame(None)
+
+        trajectories = raw.get("trajectories", [])
+        if not trajectories:
+            return DataFrame(None)
+
+        frames: list[DataFrame] = []
+        for traj in trajectories:
+            ref_id = traj.get("reference_id")
+            ref_name = traj.get("reference_item_name")
+            similarity_values = traj.get("similarity_values", [])
+            real_time_seconds = traj.get("real_time_seconds", [])
+            unix_times = traj.get("unix_times", [])
+            is_active = traj.get("is_active")
+            averaged_count = traj.get("averaged_count")
+
+            if not similarity_values:
+                continue
+
+            # Build dataframe from columnar data
+            traj_df = DataFrame(
+                {
+                    "reference_id": ref_id,
+                    "reference_item_name": ref_name,
+                    "similarity_values": similarity_values,
+                    "real_time_seconds": real_time_seconds,
+                    "unix_times": unix_times,
+                    "is_active": is_active,
+                    "averaged_count": averaged_count,
+                }
+            )
+            frames.append(traj_df)
+
+        if not frames:
+            return DataFrame(None)
+
+        df_all = concat(frames, axis=0, ignore_index=True)
+        df_all = df_all.rename(columns=self.RENAME_MAP)
+
+        idx_cols = [c for c in self.INDEX_COLS if c in df_all.columns]
+        if idx_cols:
+            df_all = df_all.set_index(idx_cols)
+
+        return df_all
+
+    def build_result(
+        self,
+        client: BaseClient,  # noqa: ARG002
+        data_id: str,
+        data_type: str,  # noqa: ARG002
+        ts_df: DataFrame,
+        *,
+        workflow: str = "",
+        window_span: float = 0.0,
+        source_data_ids: Sequence[UUID | str] | None = None,
+    ) -> SimilarityTrajectoryResult:
+        return SimilarityTrajectoryResult(
+            source_id=data_id,
+            workflow=workflow,
+            window_span=window_span,
+            timeseries_data=ts_df,
+            source_data_ids=source_data_ids,
+        )

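To illustrate the DataFrame shape to_dataframe produces, here is a hedged sketch using a hand-built payload whose fields mirror what the parser above reads; the real payload comes from the similarity/{workflow}/{data_id}/trajectory/ endpoint, and the no-argument instantiation assumes the TimeseriesProvider base class (not shown here) allows it.

from atomscale.timeseries.similarity import SimilarityTrajectoryProvider

# Hand-built payload shaped like the fields read in to_dataframe above.
raw = {
    "trajectories": [
        {
            "reference_id": "ref-001",             # illustrative ID
            "reference_item_name": "Reference A",  # illustrative name
            "similarity_values": [0.92, 0.95],
            "real_time_seconds": [0.0, 1.0],
            "unix_times": [1700000000.0, 1700000001.0],
            "is_active": True,
            "averaged_count": 3,
        }
    ]
}

provider = SimilarityTrajectoryProvider()  # assumes a no-argument constructor
df = provider.to_dataframe(raw)

print(df.index.names)    # ['Reference ID', 'Time']
print(list(df.columns))  # ['Reference Name', 'Similarity', 'UNIX Timestamp', 'Active', 'Averaged Count']
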
tests/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -27,3 +27,5 @@ class ResultIDs:
     metrology = ""
     photoluminescence = ""
     raman = ""
+    similarity_workflow = "rheed_stationary"
+    similarity_source_id = "bb3494b1-b5fb-4f3e-ac50-e4024f8aacf5"
