|
| 1 | +import spatialdata as sd |
| 2 | +import sopa |
| 3 | +import anndata as ad |
| 4 | +import pandas as pd |
| 5 | +import numpy as np |
| 6 | +from scipy import sparse |
| 7 | + |
## VIASH START
# Development defaults for running this script directly.
# NOTE(review): presumably the region between the VIASH markers is replaced
# by the component's real arguments at build time — confirm against the config.
par = {
    # Input / output locations.
    "input": "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr",
    "output": "transcripts.zarr",

    # Element keys inside the SpatialData object.
    "transcripts_key": "transcripts",
    "shapes_key": "cell_boundaries",
    "images_key": "morphology_mip",

    # Patch geometry handed to the sopa patching calls.
    "patch_width": 1200,
    "patch_overlap": 50,
    "transcript_patch_width": 200,

    # Values forwarded into the ComSeg config dictionary.
    "mean_cell_diameter": 15.0,
    "max_cell_radius": 25.0,
    "alpha": 0.5,
    "min_rna_per_cell": 5,
    "gene_column": "feature_name",
    "norm_vector": False,
    "allow_disconnected_polygon": True,
}
## VIASH END
| 28 | + |
def fixed_count_transcripts_aligned(geo_df, points, value_key):
    """Count transcripts per (cell, gene) and return them as an AnnData.

    Drop-in replacement for
    ``sopa.aggregation.transcripts._count_transcripts_aligned``. The only
    intended difference is that the count matrix ``X`` is built as a
    ``csr_matrix``, which works around an error raised during the comseg call.

    Args:
        geo_df: GeoDataFrame of cell geometries; its index becomes
            ``obs_names`` of the result.
        points: Partitioned transcript dataframe exposing ``map_partitions``
            and the columns ``"x"``, ``"y"`` and ``value_key``. The
            ``value_key`` column is converted in place to a categorical with
            known categories.
        value_key: Name of the column in ``points`` holding the gene name.

    Returns:
        AnnData with one row per geometry of ``geo_df`` and one column per
        gene category. Genes matching ``settings.gene_exclude_pattern`` are
        dropped when that pattern is set.
    """
    from functools import partial

    import geopandas as gpd
    from anndata import AnnData
    from dask.diagnostics import ProgressBar
    from scipy.sparse import csr_matrix
    from sopa._settings import settings

    def _add_csr(collected, cells, partition, gene_column, gene_names):
        # Discard transcripts of excluded genes before the spatial join.
        exclude_pattern = settings.gene_exclude_pattern
        if exclude_pattern is not None:
            excluded = partition[gene_column].str.match(exclude_pattern, case=False, na=False)
            partition = partition[~excluded]

        transcript_points = gpd.GeoDataFrame(
            partition,
            geometry=gpd.points_from_xy(partition["x"], partition["y"]),
        )
        hits = cells.sjoin(transcript_points)

        # Category code -1 marks a missing gene value; keep valid codes only.
        gene_codes = hits[gene_column].cat.codes
        has_gene = gene_codes >= 0
        row_idx = hits.index[has_gene]
        col_idx = gene_codes[has_gene]

        # Duplicate (row, col) pairs are summed by the sparse constructor,
        # so each matched transcript contributes a count of 1.
        ones = np.full(len(row_idx), 1)
        counts = csr_matrix(
            (ones, (row_idx, col_idx)),
            shape=(len(cells), len(gene_names)),
        )
        collected.append(counts)

    # Known categories give every partition the same gene -> column mapping.
    points[value_key] = points[value_key].astype("category").cat.as_known()
    gene_names = points[value_key].cat.categories.astype(str)

    n_cells, n_genes = len(geo_df), len(gene_names)
    adata = AnnData(
        X=csr_matrix((n_cells, n_genes), dtype=int),
        var=pd.DataFrame(index=gene_names),
    )
    adata.obs_names = geo_df.index.astype(str)

    # After the reset, index values of the join result are row positions.
    geo_df = geo_df.reset_index()

    partition_counts = []
    count_partition = partial(
        _add_csr,
        partition_counts,
        geo_df,
        gene_column=value_key,
        gene_names=gene_names,
    )
    with ProgressBar():
        points.map_partitions(count_partition, meta=()).compute()

    # Accumulate the per-partition matrices into the final count matrix.
    for counts in partition_counts:
        adata.X += counts

    if settings.gene_exclude_pattern is not None:
        keep_genes = ~adata.var_names.str.match(settings.gene_exclude_pattern, case=False, na=False)
        adata = adata[:, keep_genes].copy()

    return adata
| 73 | + |
| 74 | + |
# Read input SpatialData
sdata = sd.read_zarr(par["input"])

# Split the data into image patches and transcript patches; the transcript
# patches use the existing cell boundaries as prior shapes.
sopa.make_image_patches(
    sdata,
    patch_width=par["patch_width"],
    patch_overlap=par["patch_overlap"],
)
sopa.make_transcript_patches(
    sdata=sdata,
    write_cells_centroids=True,
    patch_width=par["transcript_patch_width"],
    prior_shapes_key=par["shapes_key"],
)

# ComSeg configuration assembled from the component parameters.
config = {
    "dict_scale": {"x": 1, "y": 1, "z": 1},
    "mean_cell_diameter": par["mean_cell_diameter"],
    "max_cell_radius": par["max_cell_radius"],
    "norm_vector": par["norm_vector"],
    "alpha": par["alpha"],
    "allow_disconnected_polygon": par["allow_disconnected_polygon"],
    "min_rna_per_cell": par["min_rna_per_cell"],
    "gene_column": par["gene_column"],
}

# Swap sopa's private counting helper for the CSR-based variant defined in
# this file before running ComSeg, so the segmentation call uses it.
sopa.aggregation.transcripts._count_transcripts_aligned = fixed_count_transcripts_aligned
sopa.segmentation.comseg(sdata, config)

# Create output SpatialData: the points plus a table whose obs holds the
# per-transcript "cell_id" column and whose var keeps only the gene index.
# NOTE(review): assumes "table" and the "cell_id" column exist after the
# ComSeg run — confirm against the sopa version in use.
cell_id_col = sdata["transcripts"]["cell_id"]
sdata.tables["table"] = ad.AnnData(
    obs=pd.DataFrame({"cell_id": cell_id_col}),
    var=sdata.tables["table"].var[[]],
)
sdata_new = sd.SpatialData(
    points=sdata.points,
    tables=sdata.tables,
)

output_path = par["output"]
sdata_new.write(output_path, overwrite=True)
| 115 | + |
| 116 | + |
0 commit comments