Skip to content

Commit d98249b

Browse files
authored
Merge pull request #58 from atomscale-ai/enhancement/times_series_align
Fix alignment of time series
2 parents c863854 + 9da54a2 commit d98249b

File tree

2 files changed

+216
-80
lines changed

2 files changed

+216
-80
lines changed

src/atomscale/client.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -434,15 +434,13 @@ def get_physical_sample(
434434
*,
435435
include_organization_data: bool = True,
436436
align: bool | str = False,
437-
resample: str | None = None,
438437
) -> PhysicalSampleResult:
439438
"""Get all data for a physical sample.
440439
441440
Args:
442441
physical_sample_id: Identifier of the physical sample.
443442
include_organization_data: Whether to include organization data. Defaults to True.
444443
align: Whether to align timeseries data. If truthy, an aligned DataFrame is returned.
445-
resample: Optional pandas resample rule applied after alignment.
446444
"""
447445
physical_samples: list[dict] | None = self._get( # type: ignore # noqa: PGH003
448446
sub_url="physical_samples/",
@@ -468,11 +466,7 @@ def get_physical_sample(
468466
if isinstance(align, str):
469467
join_how = align
470468

471-
ts_aligned = (
472-
align_timeseries(results, how=join_how, resample=resample)
473-
if align
474-
else None
475-
)
469+
ts_aligned = align_timeseries(results, how=join_how) if align else None
476470

477471
non_timeseries = [
478472
r
@@ -501,15 +495,13 @@ def get_project(
501495
*,
502496
include_organization_data: bool = True,
503497
align: bool | str = False,
504-
resample: str | None = None,
505498
) -> ProjectResult:
506499
"""Get all data grouped by physical sample for a project.
507500
508501
Args:
509502
project_id: Identifier of the project.
510503
include_organization_data: Whether to include organization data. Defaults to True.
511504
align: Whether to align timeseries at the project level. Defaults to False.
512-
resample: Optional pandas resample rule applied after alignment.
513505
"""
514506
# Get physical samples associated with the project, then fetch data per sample.
515507
project_samples: list[dict] = (
@@ -519,39 +511,27 @@ def get_project(
519511
return ProjectResult(project_id, None, [], None)
520512

521513
sample_results: list[PhysicalSampleResult] = []
514+
all_results: list = []
522515
for sample in project_samples:
523516
sid = sample.get("id")
524517
if not sid:
525518
continue
519+
# For project-level alignment we align once across all entries, so
520+
# skip per-sample alignment when align=True.
521+
sample_align = False if align else align
526522
sample_results.append(
527523
self.get_physical_sample(
528524
sid,
529525
include_organization_data=include_organization_data,
530-
align=align,
531-
resample=resample,
526+
align=sample_align,
532527
)
533528
)
529+
if sample_results[-1].data_results:
530+
all_results.extend(sample_results[-1].data_results)
534531

535532
project_aligned = None
536533
if align:
537-
frames = []
538-
for sample in sample_results:
539-
if sample.aligned_timeseries is None:
540-
continue
541-
renamed = sample.aligned_timeseries.copy()
542-
renamed.columns = pd.MultiIndex.from_tuples(
543-
[
544-
(sample.physical_sample_id, *tuple(col))
545-
if isinstance(col, tuple)
546-
else (sample.physical_sample_id, col)
547-
for col in renamed.columns
548-
]
549-
)
550-
frames.append(renamed)
551-
if frames:
552-
project_aligned = frames[0]
553-
for frame in frames[1:]:
554-
project_aligned = project_aligned.join(frame, how="outer")
534+
project_aligned = align_timeseries(all_results, how="outer")
555535

556536
project_name = None
557537
return ProjectResult(
@@ -654,7 +634,13 @@ def _get_result_data(
654634
raw = provider.fetch_raw(self, data_id)
655635
ts_df = provider.to_dataframe(raw)
656636

657-
return provider.build_result(self, data_id, data_type, ts_df)
637+
result_obj = provider.build_result(self, data_id, data_type, ts_df)
638+
if catalogue_entry:
639+
# Store upload datetime for alignment fallback when only relative time is available.
640+
upload_dt = catalogue_entry.get("upload_datetime")
641+
if upload_dt:
642+
result_obj.upload_datetime = upload_dt
643+
return result_obj
658644

659645
# Fallback for unknown/unsupported data types
660646
return UnknownResult(

0 commit comments

Comments
 (0)