Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci_pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ jobs:
strategy:
matrix:
os: [ubuntu-22.04, macos-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
exclude:
- os: macos-latest
python-version: '3.9'
- os: macos-latest
python-version: '3.10'
name: "Core, Python ${{ matrix.python-version }}, ${{ matrix.os }}"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_poetry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:

strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
**Note**: Numbers like (\#123) point to closed Pull Requests on the fractal-tasks-core repository.

* Tasks:
* Refactor projection task to use ngio
* Dependencies:
* Add ngio
* CI:
* Remove Python 3.9 from the CI matrix

# 1.3.2
* Tasks:
* Add percentile-based rescaling to calculate registration task to make it more robust (\#848)
Expand Down
166 changes: 50 additions & 116 deletions fractal_tasks_core/tasks/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,12 @@
import logging
from typing import Any

import anndata as ad
import dask.array as da
import zarr
from ngio import NgffImage
from pydantic import validate_call
from zarr.errors import ContainsArrayError

from fractal_tasks_core.ngff import load_NgffImageMeta
from fractal_tasks_core.pyramids import build_pyramid
from fractal_tasks_core.roi import (
convert_ROIs_from_3D_to_2D,
)
from fractal_tasks_core.tables import write_table
from fractal_tasks_core.tables.v1 import get_tables_list_v1

from fractal_tasks_core.tasks.io_models import InitArgsMIP
from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod
from fractal_tasks_core.zarr_utils import OverwriteNotAllowedError


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,116 +49,61 @@ def projection(
logger.info(f"{method=}")

# Read image metadata
ngff_image = load_NgffImageMeta(init_args.origin_url)
# Currently not using the validation models due to wavelength_id issue
# See #681 for discussion
# new_attrs = ngff_image.model_dump(exclude_none=True)
# Current way to get the necessary metadata for MIP
group = zarr.open_group(init_args.origin_url, mode="r")
new_attrs = group.attrs.asdict()

# Create the zarr image with correct
new_image_group = zarr.group(zarr_url)
new_image_group.attrs.put(new_attrs)

# Load 0-th level
data_czyx = da.from_zarr(init_args.origin_url + "/0")
num_channels = data_czyx.shape[0]
chunksize_y = data_czyx.chunksize[-2]
chunksize_x = data_czyx.chunksize[-1]
logger.info(f"{num_channels=}")
logger.info(f"{chunksize_y=}")
logger.info(f"{chunksize_x=}")

# Loop over channels
accumulate_chl = []
for ind_ch in range(num_channels):
# Perform MIP for each channel of level 0
project_yx = da.stack(
[method.apply(data_czyx[ind_ch], axis=0)], axis=0
)
accumulate_chl.append(project_yx)
accumulated_array = da.stack(accumulate_chl, axis=0)

# Write to disk (triggering execution)
try:
accumulated_array.to_zarr(
f"{zarr_url}/0",
overwrite=init_args.overwrite,
dimension_separator="/",
write_empty_chunks=False,
)
except ContainsArrayError as e:
error_msg = (
f"Cannot write array to zarr group at '{zarr_url}/0', "
f"with {init_args.overwrite=} (original error: {str(e)}).\n"
"Hint: try setting overwrite=True."
original_ngff_image = NgffImage(init_args.origin_url)
orginal_image = original_ngff_image.get_image()

if orginal_image.is_2d or orginal_image.is_2d_time_series:
raise ValueError(
"The input image is 2D, "
"projection is only supported for 3D images."
)
logger.error(error_msg)
raise OverwriteNotAllowedError(error_msg)

# Starting from on-disk highest-resolution data, build and write to disk a
# pyramid of coarser levels
build_pyramid(
zarrurl=zarr_url,
on_disk_shape = orginal_image.on_disk_shape
logger.info(f"Original shape: {on_disk_shape=}")

on_disk_z_index = orginal_image.find_axis("z")

new_on_disk_shape = list(on_disk_shape)
new_on_disk_shape[on_disk_z_index] = 1

pixel_size = orginal_image.pixel_size
pixel_size.z = 1.0
logger.info(f"New shape: {new_on_disk_shape=}")

new_ngff_image = original_ngff_image.derive_new_image(
store=zarr_url,
name="MIP",
on_disk_shape=new_on_disk_shape,
pixel_sizes=pixel_size,
overwrite=init_args.overwrite,
num_levels=ngff_image.num_levels,
coarsening_xy=ngff_image.coarsening_xy,
chunksize=(1, 1, chunksize_y, chunksize_x),
)
new_image = new_ngff_image.get_image()

# Copy over any tables from the original zarr
# Generate the list of tables:
tables = get_tables_list_v1(init_args.origin_url)
roi_tables = get_tables_list_v1(init_args.origin_url, table_type="ROIs")
non_roi_tables = [table for table in tables if table not in roi_tables]

for table in roi_tables:
logger.info(
f"Reading {table} from "
f"{init_args.origin_url=}, convert it to 2D, and "
"write it back to the new zarr file."
)
new_ROI_table = ad.read_zarr(f"{init_args.origin_url}/tables/{table}")
old_ROI_table_attrs = zarr.open_group(
f"{init_args.origin_url}/tables/{table}"
).attrs.asdict()

# Convert 3D ROIs to 2D
pxl_sizes_zyx = ngff_image.get_pixel_sizes_zyx(level=0)
new_ROI_table = convert_ROIs_from_3D_to_2D(
new_ROI_table, pixel_size_z=pxl_sizes_zyx[0]
)
# Write new table
write_table(
new_image_group,
table,
new_ROI_table,
table_attrs=old_ROI_table_attrs,
overwrite=init_args.overwrite,
)
# Process the image
z_axis_index = orginal_image.dataset.axes_names.index("z")
source_dask = orginal_image.get_array(
mode="dask", preserve_dimensions=True
)

for table in non_roi_tables:
logger.info(
f"Reading {table} from "
f"{init_args.origin_url=}, and "
"write it back to the new zarr file."
)
new_non_ROI_table = ad.read_zarr(
f"{init_args.origin_url}/tables/{table}"
)
old_non_ROI_table_attrs = zarr.open_group(
f"{init_args.origin_url}/tables/{table}"
).attrs.asdict()

# Write new table
write_table(
new_image_group,
table,
new_non_ROI_table,
table_attrs=old_non_ROI_table_attrs,
overwrite=init_args.overwrite,
dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index)
dest_dask = da.expand_dims(dest_dask, axis=z_axis_index)
new_image.set_array(dest_dask)
new_image.consolidate()
# Ends

# Copy over the tables
for roi_table in original_ngff_image.table.list(table_type="roi_table"):
table = original_ngff_image.table.get_table(roi_table)
mip_table = new_ngff_image.table.new(
roi_table, table_type="roi_table", overwrite=True
)
roi_list = []
for roi in table.rois:
roi.z_length = roi.z + 1
roi_list.append(roi)

mip_table.set_rois(roi_list, overwrite=True)
mip_table.consolidate()

# Generate image_list_updates
image_list_update_dict = dict(
Expand Down
Loading
Loading