import os
import shutil
from pathlib import Path
import xarray as xr
# Import the dask submodules used below explicitly; `import dask` alone does not
# expose `dask.array` or `dask.dataframe`
import dask.array
import dask.dataframe
import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd
import sopa


## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
    'input_ist': 'resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr',
    'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr',
    'transcripts_key': 'transcripts',
    'coordinate_system': 'global',
    'output': './temp/methods/baysor/baysor_assigned_transcripts.zarr',

    'force_2d': 'false',
    'min_molecules_per_cell': 50,
    'scale': -1.0,  # NOTE: for parameter selection, see https://github.com/gustaveroussy/sopa/tree/main/workflow/config
    'scale_std': "25%",
    'n_clusters': 4,
    'prior_segmentation_confidence': 0.8,
}
meta = {
    'name': 'baysor_transcript_assignment',
    'temp_dir': "./temp/methods/baysor",
    'cpus': 4,
}
## VIASH END

TMP_DIR = Path(meta["temp_dir"] or "/tmp")
TMP_DIR.mkdir(parents=True, exist_ok=True)

CONFIG_TOML = TMP_DIR / "config.toml"


##############################
# Basic assignment for prior #
##############################

# Sopa expects the prior segmentation as a `cell_id` column in the transcripts table.
# Generate this column with a basic pixel-lookup assignment:
print('Reading input files', flush=True)
sdata = sd.read_zarr(par['input_ist'])
sdata_segm = sd.read_zarr(par['input_segmentation'])

# Check that the requested coordinate system is available in both inputs
transcripts_coord_systems = sd.transformations.get_transformation(sdata[par["transcripts_key"]], get_all=True).keys()
assert par['coordinate_system'] in transcripts_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in transcripts data."
segmentation_coord_systems = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).keys()
assert par['coordinate_system'] in segmentation_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in segmentation data."

print('Transforming transcripts coordinates', flush=True)
transcripts = sd.transform(sdata[par['transcripts_key']], to_coordinate_system=par['coordinate_system'])

# In case of a translation transformation of the segmentation (e.g. crop of the data),
# we need to adjust the transcript coordinates
trans = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True)[par['coordinate_system']].inverse()
transcripts = sd.transform(transcripts, trans, par['coordinate_system'])
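
# Net effect of the two transforms above: transcript coordinates are mapped from their
# intrinsic space into the chosen coordinate system, then back into the segmentation's
# pixel grid (e.g. undoing the translation of a cropped segmentation), so they can be
# used directly as indices into the label image below.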

print('Assigning transcripts to cell ids', flush=True)
# Cast transcript coordinates to integer pixel indices (float values are truncated)
y_coords = transcripts.y.compute().to_numpy(dtype=np.int64)
x_coords = transcripts.x.compute().to_numpy(dtype=np.int64)
if isinstance(sdata_segm["segmentation"], xr.DataTree):
    # Multiscale labels: take the full-resolution level (scale0)
    label_image = sdata_segm["segmentation"]["scale0"].image.to_numpy()
else:
    label_image = sdata_segm["segmentation"].to_numpy()
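
# Added sanity check (not in the original pipeline): the fancy indexing below assumes
# every transcript falls inside the label image after the inverse transform; negative
# indices would silently wrap around and too-large ones would raise an IndexError.
assert (y_coords >= 0).all() and (x_coords >= 0).all(), "Negative transcript pixel coordinates"
assert y_coords.max() < label_image.shape[0] and x_coords.max() < label_image.shape[1], \
    "Transcript pixel coordinates fall outside the segmentation image"
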
cell_id_dask_series = dask.dataframe.from_dask_array(
    dask.array.from_array(
        # Look up the label value at each transcript position; chunk to match the
        # partitioning of the transcripts dataframe
        label_image[y_coords, x_coords],
        chunks=tuple(sdata[par['transcripts_key']].map_partitions(len).compute())
    ),
    index=sdata[par['transcripts_key']].index
)
sdata[par['transcripts_key']]["cell_id"] = cell_id_dask_series


########################
# Run baysor with sopa #
########################

# Create reduced sdata
sdata_sopa = sd.SpatialData(
    points={
        "transcripts": sdata[par['transcripts_key']]
    },
)

# Write config to toml
print('Writing config to toml', flush=True)
# Lower-case force_2d so that both the string 'false' and a boolean False render
# as a valid TOML boolean
toml_str = f"""[data]
x = "x"
y = "y"
z = "z"
gene = "feature_name"
force_2d = {str(par['force_2d']).lower()}
min_molecules_per_cell = {int(par['min_molecules_per_cell'])}
exclude_genes = ""

[segmentation]
scale = {float(par['scale'])}
scale_std = "{par['scale_std']}"
n_clusters = {int(par['n_clusters'])}
prior_segmentation_confidence = {float(par['prior_segmentation_confidence'])}
"""
with open(CONFIG_TOML, "w") as toml_file:
    toml_file.write(toml_str)
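
# With the default `par` values above, the rendered config.toml reads:
#
#   [data]
#   x = "x"
#   y = "y"
#   z = "z"
#   gene = "feature_name"
#   force_2d = false
#   min_molecules_per_cell = 50
#   exclude_genes = ""
#
#   [segmentation]
#   scale = -1.0
#   scale_std = "25%"
#   n_clusters = 4
#   prior_segmentation_confidence = 0.8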


# Make transcript patches
sopa.make_transcript_patches(sdata_sopa, patch_width=2000, patch_overlap=50, prior_shapes_key="cell_id")
sopa.settings.parallelization_backend = "dask"
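# With the "dask" backend set above, sopa executes the per-patch Baysor runs in
# parallel; the `cell_id` column selected via `prior_shapes_key` serves as the
# segmentation prior, weighted by `prior_segmentation_confidence` in the config.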
| 118 | + |
| 119 | +# Run baysor |
| 120 | +sopa.segmentation.baysor(sdata_sopa, config=str(CONFIG_TOML)) |
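# Baysor's cell polygons are written back to `sdata_sopa` as the "baysor_boundaries"
# shapes element, which the assignment step below relies on.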
| 121 | + |
| 122 | +# Assign transcripts to cell ids |
| 123 | +sopa.spatial.assign_transcript_to_cell( |
| 124 | + sdata_sopa, |
| 125 | + points_key="transcripts", |
| 126 | + shapes_key="baysor_boundaries", |
| 127 | + key_added="cell_id", |
| 128 | + unassigned_value=0 |
| 129 | +) |
| 130 | + |
| 131 | + |
| 132 | + |
# Create objects for cells table
print('Creating objects for cells table', flush=True)
# Create a new .obs for cells based on the segmentation output (matching the transcripts' 'cell_id')
unique_cells = np.unique(sdata_sopa["transcripts"]["cell_id"].compute())

# Remove the '0' (noise/background) label if present
unique_cells = unique_cells[unique_cells != 0]

# Transform into a pandas Series and check
cell_id_col = pd.Series(unique_cells, name='cell_id', index=unique_cells)
assert 0 not in cell_id_col.values, "Found '0' in cell_id column of assignment output cell matrix"


# Create transcripts only sdata
print('Subsetting to transcripts cell id data', flush=True)
sdata_transcripts_only = sd.SpatialData(
    points={
        "transcripts": sdata_sopa['transcripts']
    },
    tables={
        "table": ad.AnnData(
            obs=pd.DataFrame(cell_id_col),
            var=sdata.tables["table"].var[[]]
        )
    }
)
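# Note: `var=sdata.tables["table"].var[[]]` keeps the gene index from the input table
# but drops all of its columns; this component produces only cell ids (obs), not a
# counts matrix.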

# Write output
print('Write transcripts with cell ids', flush=True)
if os.path.exists(par["output"]):
    shutil.rmtree(par["output"])

sdata_transcripts_only.write(par['output'])
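
# Optional round-trip check (hypothetical, not part of the original script): re-open
# the written zarr and confirm the transcripts and table elements survive the write.
# sdata_check = sd.read_zarr(par['output'])
# print(sdata_check)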