1+ import dask
2+ import xarray as xr
13import spatialdata as sd
24import sopa
35import anndata as ad
46import pandas as pd
57import numpy as np
6- from scipy import sparse
78
89## VIASH START
910par = {
10- "input" : "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr" ,
11- "output" : "transcripts.zarr" ,
12-
11+ "input_ist" : "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr" ,
12+ "input_segmentation" : "resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr" ,
1313 "transcripts_key" : "transcripts" ,
14- "shapes_key" : "cell_boundaries" ,
15- "images_key" : "morphology_mip" ,
14+ "coordinate_system" : "global" ,
15+ "output" : "temp/comseg/transcripts.zarr" ,
16+
1617 "patch_width" : 1200 ,
1718 "patch_overlap" : 50 ,
1819 "transcript_patch_width" : 200 ,
@@ -72,19 +73,36 @@ def _add_csr(X_partitions, geo_df, partition, gene_column, gene_names ):
7273 return adata
7374
7475
# Read input files
print('Reading input files', flush=True)
sdata = sd.read_zarr(par['input_ist'])
sdata_segm = sd.read_zarr(par['input_segmentation'])


# Convert the prior segmentation to polygons.
# A multiscale segmentation is stored as a DataTree; use the full-resolution
# level ("scale0") for vectorization, otherwise vectorize the image directly.
if isinstance(sdata_segm["segmentation"], xr.DataTree):
    shapes_gdf = sopa.shapes.vectorize(sdata_segm["segmentation"]["scale0"].image)
else:
    shapes_gdf = sopa.shapes.vectorize(sdata_segm["segmentation"])

# Register the polygons on the IST SpatialData, carrying over all coordinate
# transformations from the source segmentation so the shapes stay aligned.
sdata["segmentation_boundaries"] = sd.models.ShapesModel.parse(
    shapes_gdf,
    transformations=sd.transformations.get_transformation(
        sdata_segm["segmentation"], get_all=True
    ).copy(),
)

# Make patches
sopa.make_image_patches(sdata, patch_width=par["patch_width"], patch_overlap=par["patch_overlap"])

# Patch the transcripts as well, writing cell centroids so ComSeg can use the
# prior segmentation ("segmentation_boundaries") as seeds.
transcript_patch_args = {
    "sdata": sdata,
    "write_cells_centroids": True,
    "patch_width": par["transcript_patch_width"],
    "prior_shapes_key": "segmentation_boundaries",
}

sopa.make_transcript_patches(**transcript_patch_args)
87104
105+ # Run ComSeg
88106config = {
89107 "dict_scale" : {"x" : 1 , "y" : 1 , "z" : 1 },
90108 "mean_cell_diameter" : par ["mean_cell_diameter" ],
@@ -96,21 +114,45 @@ def _add_csr(X_partitions, geo_df, partition, gene_column, gene_names ):
96114 "gene_column" : par ["gene_column" ],
97115}
# Monkey-patch sopa's transcript counting with the fixed implementation
# defined earlier in this file (works around an upstream bug).
sopa.aggregation.transcripts._count_transcripts_aligned = fixed_count_transcripts_aligned
# sopa.settings.parallelization_backend = 'dask'
sopa.segmentation.comseg(sdata, config)

# Assign transcripts to cell ids; transcripts outside any ComSeg boundary get 0.
sopa.spatial.assign_transcript_to_cell(
    sdata,
    points_key="transcripts",
    shapes_key="comseg_boundaries",
    key_added="cell_id",
    unassigned_value=0,
)
129+
# Create output SpatialData

# Create objects for cells table
print('Creating objects for cells table', flush=True)
# np.unique returns sorted unique cell ids; drop the "unassigned" id 0
# (a boolean mask is equivalent to deleting the single 0 entry, since
# unique ids contain 0 at most once).
unique_cells = np.unique(sdata["transcripts"]["cell_id"])
unique_cells = unique_cells[unique_cells != 0]
cell_id_col = pd.Series(unique_cells, name='cell_id', index=unique_cells)

# Create transcripts only sdata
print('Subsetting to transcripts cell id data', flush=True)
sdata_transcripts_only = sd.SpatialData(
    points={
        "transcripts": sdata['transcripts'],
    },
    tables={
        # One obs row per cell id; keep the var index (gene names) but no
        # var columns, and no expression matrix.
        "table": ad.AnnData(
            obs=pd.DataFrame(cell_id_col),
            var=sdata.tables["table"].var[[]],
        ),
    },
)
153+
112154
# Write the transcripts-only SpatialData to the requested output zarr store.
output_path = par['output']
sdata_transcripts_only.write(output_path, overwrite=True)
115157
116158
0 commit comments