1+ """
2+ This script will pre-process the raw AGEA volumes and write a newer version to disk.
3+ The pre-processing steps include:
4+ - merge hemispheres
5+ - impute using PPCA
6+ - remove curtaining effect trying to minimize the slice to slice amplitude variation per Cosmos regions
7+ """
8+
9+ from pathlib import Path
10+ import tqdm
11+ import joblib
12+
13+ import numpy as np
14+ import pandas as pd
15+ import scipy .ndimage
16+ import seaborn as sns
17+ import matplotlib .pyplot as plt
18+
19+ import scipy .optimize
20+ from sklearn .model_selection import train_test_split
21+ from sklearn .neural_network import MLPClassifier
22+ from sklearn .decomposition import PCA
23+ from sklearn .preprocessing import StandardScaler
24+
25+ import pyppca # pip install pyppca
26+ from iblatlas .genomics import agea
27+
# load the raw AGEA gene expression volumes and the companion atlas
df_genes, gene_expression_volumes, atlas_agea = agea.load()
# NOTE(review): despite the name, this is True where label == 0, i.e. voxels
# OUTSIDE of the brain (void label) — confirm before reusing
mask_brain = atlas_agea.label == 0

# destination folder for the processed volumes
# folder_volumes = Path('/datadisk/Data/2025/denoise_agea')  # alternate workstation path (was a dead duplicate assignment)
folder_volumes = Path('/mnt/s1/2025/denoise_agea')
33+
34+
def convert_to_dual_volume(input_memmap, output_memmap):
    """
    Mirror-average each gene volume across its first axis to merge hemispheres.

    Negative values encode missing data in the input; they are set to NaN so that
    nanmean keeps the value from the hemisphere that has data, and voxels missing
    in both hemispheres are re-encoded as -1 in the output.

    :param input_memmap: array or memmap of shape (n_genes, n_ml, n_dv, n_ap)
    :param output_memmap: pre-allocated array or memmap of the same shape and dtype
     float16, written in place
    :return: None, output_memmap is modified in place
    """
    for igene in tqdm.tqdm(np.arange(input_memmap.shape[0])):
        gvol = np.copy(input_memmap[igene]).astype(float)
        gvol[gvol < 0] = np.nan  # negative values encode missing data
        # average each voxel with its mirror across axis 0; NaNs are ignored by nanmean
        gvol_dual = np.nanmean(np.stack((gvol, np.flip(gvol, axis=0)), axis=3), axis=3)
        gvol_dual[np.isnan(gvol_dual)] = -1  # missing in both hemispheres
        output_memmap[igene] = gvol_dual.astype(np.float16)
44+
45+
def compute_agea_pca(gene_expression_volumes, atlas_agea, pca_n_components=50):
    """
    Compute PCA embedding for AGEA and return the embedding and the indices of the voxels within the brain volume.

    :param gene_expression_volumes: numpy array of shape (n_genes, n_ml, n_dv, n_ap)
    :param atlas_agea: AGEA atlas object
    :param pca_n_components: Number of components for PCA
    :return: PCA embedding (n_voxels_embedding, n_pca_components)
     and flat indices of selected voxels within the brain volume np.array(int64) (n_voxels_embedding)
    """
    labels_flat = atlas_agea.label.flatten()
    inside_idx = np.flatnonzero(labels_flat != 0)  # discard void voxels
    n_genes = gene_expression_volumes.shape[0]
    # reshape into (n_voxels_inside, n_genes): voxels as rows, genes as columns
    voxel_genes = np.copy(
        gene_expression_volumes.reshape((n_genes, -1))[:, inside_idx].astype(np.float32).transpose())
    # negative values encode missing data; keep only voxels with less than 60% missing genes
    p_missing = np.mean(voxel_genes < 0, axis=1)
    keep = p_missing < 0.6
    voxel_genes = voxel_genes[keep, :]
    embedding_idx = inside_idx[keep]
    # standardize the gene expressions, then project onto the first principal components
    scaled = StandardScaler().fit_transform(voxel_genes)
    embedding = PCA(n_components=pca_n_components).fit_transform(scaled)
    return embedding, embedding_idx
71+
72+
73+ def compute_agea_ppca (input_memmap , atlas_agea , output_memmap = None , ppca_n_components = 50 ):
74+ inside_idx = np .where (atlas_agea .label .flatten () != 0 )[0 ]
75+ outside_idx = np .unravel_index (np .where (atlas_agea .label .flatten () == 0 ), input_memmap .shape [1 :])
76+ ng = input_memmap .shape [0 ]
77+ gexps = input_memmap .reshape ((ng , - 1 ))[:, inside_idx ].astype (np .float32 ).transpose ()
78+ gexps [gexps < 0 ] = np .nan
79+ C , ss , M , X , Ye = pyppca .ppca (gexps , d = ppca_n_components , dia = True )
80+ output_memmap = np .copy (input_memmap ) if output_memmap is None else output_memmap
81+ output_memmap [:, * np .unravel_index (inside_idx , input_memmap .shape [1 :])] = Ye .T
82+ output_memmap [:, * outside_idx ] = input_memmap [:, * outside_idx ]
83+ return output_memmap
84+
85+
def train_region_predictor(atlas_agea, embedding, embedding_idx, mapping=None):
    """
    From the PCA embedding, predict the region label of each voxel using a MLP
    and return the accuracy on a held-out 20% test split.

    Reference accuracies from a previous run:
    - Accuracy: 0.7351404310907903 for allen regions
    - Accuracy: 0.9329686479425212 for cosmos regions

    :param atlas_agea: AGEA atlas object
    :param embedding: np.array (n_voxels_embedding, n_pca_components)
    :param embedding_idx: flat indices of the embedded voxels within the brain volume
    :param mapping: None to predict Allen region ids directly, or the name of a target
     mapping (e.g. 'Cosmos') to remap the Allen ids before training
    :return: accuracy (float)
    """
    aids = np.abs(atlas_agea.regions.id[atlas_agea.label.flatten()[embedding_idx]])
    if mapping is not None:
        # bugfix: remap to the requested mapping instead of the hard-coded 'Cosmos'
        labels = atlas_agea.regions.remap(aids, source_map='Allen', target_map=mapping)
    else:
        labels = aids
    X_train, X_test, y_train, y_test = train_test_split(embedding, labels, test_size=0.2, random_state=42)
    mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=300)
    mlp.fit(X_train, y_train)
    accuracy = mlp.score(X_test, y_test)
    print("Accuracy:", accuracy)
    return accuracy
110+
111+
def curtaining(volume_agea, max_ratio=2, atlas=None):
    """
    Removes the coronal slice curtaining effect from a single gene expression volume.

    Each coronal slice (index along the last axis) gets a multiplicative weight,
    bounded in [1 / max_ratio, max_ratio], chosen to minimise the count-weighted
    absolute deviation of the log median expression per Cosmos region across slices.

    :param volume_agea: np.array (n_ml, n_dv, n_ap) single gene expression volume,
     negative values encode missing data
    :param max_ratio: bound for the per-slice correction weights
    :param atlas: AGEA atlas object; defaults to the module-level atlas_agea
    :return: np.array, the weighted volume (same shape as volume_agea)
    """
    atlas = atlas_agea if atlas is None else atlas

    def objective_function(weights, expression, counts):
        # count-weighted absolute deviation of the log-expression around the
        # per-region median across slices (robust alternative to a std)
        weighted_x = np.log(expression * weights[:, np.newaxis])
        abs_dev = np.abs(weighted_x - np.nanmedian(weighted_x, axis=0))
        return np.nansum(abs_dev * counts)

    df_vol = pd.DataFrame({
        'iy': np.tile(np.arange(volume_agea.shape[2]), np.prod(volume_agea.shape[:2])),
        'gene': volume_agea.flatten(),
        'rindex': atlas.label.flatten()}
    )
    # keep voxels with data (gene >= 0) that are inside the brain (rindex != 0)
    df_vol = df_vol.loc[np.logical_and(df_vol['gene'] >= 0, df_vol['rindex'] != 0)]
    df_vol['allen_id'] = np.abs(atlas.regions.id[df_vol['rindex']])
    df_vol['cosmos_id'] = atlas.regions.remap(df_vol['allen_id'].to_numpy(), source_map='Allen',
                                              target_map='Cosmos')
    # here maybe median is more robust than mean for outliers
    df_slices = df_vol.pivot_table(index='iy', values='gene', aggfunc=['median', 'count'], columns='cosmos_id')
    expression = df_slices.loc[:, 'median'].to_numpy()
    counts = df_slices.loc[:, 'count'].to_numpy()
    n_slices = expression.shape[0]
    # find the bounded per-slice weights that minimize the objective
    result = scipy.optimize.minimize(
        objective_function,
        np.ones(n_slices),
        args=(expression, counts),
        method='L-BFGS-B',
        bounds=[(1 / max_ratio, max_ratio) for _ in range(n_slices)]
    )
    # slices without any valid voxel keep a unit weight
    coronal_weights = np.ones(atlas.bc.ny)
    coronal_weights[df_slices.index] = result['x']
    return volume_agea * (coronal_weights[np.newaxis, np.newaxis, :])
150+
151+
def curtaining_parallel(input_memmap, output_memmap=None):
    """
    Try to remove the coronal slice curtaining effect in the AGEA volume,
    applying :func:`curtaining` to each gene volume in parallel.

    :param input_memmap: array or memmap of shape (n_genes, n_ml, n_dv, n_ap)
    :param output_memmap: optional pre-allocated array or memmap written in place;
     if None, a copy of the input is corrected instead.
     NOTE(review): with a process-based joblib backend, writes to a plain in-memory
     copy may be lost in the workers — pass a memmap for the result to persist; confirm
    :return: output_memmap (bugfix: was returning None, discarding the internal copy)
    """
    output_memmap = np.copy(input_memmap) if output_memmap is None else output_memmap

    def compute_single_gene(igene):
        # each job writes its corrected gene volume in place
        output_memmap[igene] = curtaining(input_memmap[igene])

    ng = input_memmap.shape[0]
    jobs = (joblib.delayed(compute_single_gene)(igene) for igene in np.arange(ng))
    # return_as='generator' lets tqdm display progress as jobs complete
    list(tqdm.tqdm(joblib.Parallel(n_jobs=joblib.cpu_count() - 1, return_as='generator')(jobs), total=ng))
    return output_memmap
168+
169+
# %%
# Build (or reopen) the processed volumes as float16 memmaps on disk.
RECOMPUTE = False  # set to True to regenerate all processed volumes from scratch
files = {
    'dual': folder_volumes.joinpath('dual_memmap.bin'),
    'ppca': folder_volumes.joinpath('ppca.bin'),
    'dual_ppca': folder_volumes.joinpath('dual_ppca.bin'),
    'ppca_dual': folder_volumes.joinpath('ppca_dual.bin'),
    'dual_ppca_curtain': folder_volumes.joinpath('dual_ppca_curtain.bin'),
    'ppca_dual_curtain': folder_volumes.joinpath('ppca_dual_curtain.bin')
}
if RECOMPUTE:
    print('Create memmaps...')
    # mode 'w+' creates/overwrites the binary files on disk
    memmaps = {k: np.memmap(v, dtype=np.float16, mode='w+', offset=0,
                            shape=gene_expression_volumes.shape) for k, v in files.items()}
    # two processing orders are compared: dual -> ppca and ppca -> dual,
    # then curtaining is applied on top of each
    convert_to_dual_volume(gene_expression_volumes, memmaps['dual'])
    compute_agea_ppca(memmaps['dual'], atlas_agea, output_memmap=memmaps['dual_ppca'])
    compute_agea_ppca(gene_expression_volumes, atlas_agea, output_memmap=memmaps['ppca'])
    convert_to_dual_volume(memmaps['ppca'], memmaps['ppca_dual'])
    curtaining_parallel(memmaps['dual_ppca'], output_memmap=memmaps['dual_ppca_curtain'])
    curtaining_parallel(memmaps['ppca_dual'], output_memmap=memmaps['ppca_dual_curtain'])
else:
    # mode 'r+' reopens the existing files read/write
    memmaps = {k: np.memmap(v, dtype=np.float16, mode='r+', offset=0,
                            shape=gene_expression_volumes.shape) for k, v in files.items()}
memmaps['baseline'] = gene_expression_volumes  # raw volumes kept as a comparison baseline
194+
# %%
# Evaluate each volume by training region classifiers on its PCA embedding;
# accuracies are cached in a csv so already-scored volumes are skipped.
file_results = folder_volumes.joinpath('results.csv')
if file_results.exists() and not RECOMPUTE:
    df_results = pd.read_csv(file_results)
    to_skip = list(df_results.volume.unique())  # volumes already scored in a previous run
else:
    df_results = pd.DataFrame(columns=['volume', 'region', 'accuracy'])
    to_skip = []

results = []
for k, vol in memmaps.items():
    if k in to_skip:
        continue
    print(f'Compute PCA embeddings for {k}...')
    embedding, embedding_idx = compute_agea_pca(memmaps[k], atlas_agea)
    # score the embedding at both granularities: fine Allen regions and coarse Cosmos regions
    print(f'Compute Allen region prediction for {k}...')
    accuracy = train_region_predictor(atlas_agea, embedding, embedding_idx, mapping=None)
    results.append(dict(volume=k, region='Allen', accuracy=accuracy))
    print(f'Compute Cosmos region prediction for {k}...')
    accuracy = train_region_predictor(atlas_agea, embedding, embedding_idx, mapping='Cosmos')
    results.append(dict(volume=k, region='Cosmos', accuracy=accuracy))

if len(results) > 0:
    # append the new scores to the cached results and persist them
    df_results = pd.concat([df_results, pd.DataFrame(results)], ignore_index=True, axis=0)
    df_results.to_csv(file_results, index=False)
220+
221+
# %% plot the accuracy for each method and region
# note: only a subset of the processed volumes is displayed, in processing order
volume_order = ['baseline', 'ppca', 'dual', 'ppca_dual', 'ppca_dual_curtain']
sns.catplot(x='region', y='accuracy', hue='volume', data=df_results, kind='bar', hue_order=volume_order)
plt.show()

# manual upload commands for the final processed volume:
# aws s3 cp /mnt/s1/2025/denoise_agea/ppca_dual_curtain.bin s3://ibl-brain-wide-map-public/atlas/agea/gene-expression-processed.bin # NOQA
# touch /mnt/s1/2025/denoise_agea/2025-03-18.version
# aws s3 cp /mnt/s1/2025/denoise_agea/2025-03-18.version s3://ibl-brain-wide-map-public/atlas/agea/2025-03-18.version # NOQA
# (removed page-scrape residue: "0 commit comments" — not part of the script)