From d6de2035d3e318c805f89f78529bad42065e4674 Mon Sep 17 00:00:00 2001
From: Elihei2
Date: Thu, 24 Jul 2025 18:33:16 +0200
Subject: [PATCH 1/3] added negative sampling from neighbor bds, and mutually exclusive gene pair loss

---
 .DS_Store | Bin 8196 -> 10244 bytes
 scripts/0_data_creation_5k_nucleus.py | 4 +-
 scripts/1_train_5k.py | 11 +-
 scripts/4_xenium_explorer.py | 41 ++
 scripts/batch_run_xenium.zip | Bin 0 -> 8921 bytes
 scripts/batch_run_xenium/create_data_batch.py | 73 +++
 scripts/batch_run_xenium/create_data_batch.sh | 43 ++
 .../batch_run_xenium/create_data_batch_BrM.sh | 43 ++
 .../batch_run_xenium/create_data_batch_EwS.sh | 43 ++
 .../batch_run_xenium/create_data_batch_GB.sh | 43 ++
 scripts/batch_run_xenium/predict_batch.py | 68 +++
 scripts/batch_run_xenium/predict_batch.sh | 79 ++++
 scripts/batch_run_xenium/train_batch.py | 54 +++
 scripts/batch_run_xenium/train_batch.sh | 18 +
 scripts/batch_run_xenium/train_batch_EwS.sh | 18 +
 scripts/create_data_fast_sample.py | 71 ++-
 scripts/import os.py | 273 +++++++++++
 scripts/predict_5k_yiheng.py | 73 +++
 scripts/predict_model_sample.py | 6 +-
 scripts/predict_project24.py | 73 +++
 scripts/train_MNG_5k.sh | 8 +-
 scripts/train_model_sample.py | 26 +-
 src/segger/data/parquet/_utils.py | 136 +++++-
 src/segger/data/parquet/sample.py | 81 +++-
 src/segger/prediction/predict_parquet.py | 2 +-
 src/segger/training/train.py | 38 +-
 src/segger/validation/xenium_explorer.py | 431 +++++++++---------
 27 files changed, 1441 insertions(+), 315 deletions(-)
 create mode 100644 scripts/4_xenium_explorer.py
 create mode 100644 scripts/batch_run_xenium.zip
 create mode 100644 scripts/batch_run_xenium/create_data_batch.py
 create mode 100644 scripts/batch_run_xenium/create_data_batch.sh
 create mode 100644 scripts/batch_run_xenium/create_data_batch_BrM.sh
 create mode 100644 scripts/batch_run_xenium/create_data_batch_EwS.sh
 create mode 100644 scripts/batch_run_xenium/create_data_batch_GB.sh
 create mode 100644 scripts/batch_run_xenium/predict_batch.py
 create mode 100644 scripts/batch_run_xenium/predict_batch.sh
 create mode 100644 scripts/batch_run_xenium/train_batch.py
 create mode 100644 scripts/batch_run_xenium/train_batch.sh
 create mode 100644 scripts/batch_run_xenium/train_batch_EwS.sh
 create mode 100644 scripts/import os.py
 create mode 100644 scripts/predict_5k_yiheng.py
 create mode 100644 scripts/predict_project24.py

diff --git a/.DS_Store b/.DS_Store
index 580c4d3fb5d1c959a19f59206a2155441d9890f0..ec2e1f4516b97f824ad79d48cde8596ceca81289 100644
GIT binary patch
delta 1279
(binary data omitted)

delta 100
(binary data omitted)
diff --git a/scripts/0_data_creation_5k_nucleus.py b/scripts/0_data_creation_5k_nucleus.py
index 26b94657..199568ac 100644
--- a/scripts/0_data_creation_5k_nucleus.py
+++ b/scripts/0_data_creation_5k_nucleus.py
@@ -49,8 +49,8 @@
 
 # subsample the scRNAseq if needed
-# sc.pp.subsample(scrnaseq, 0.1)
-# scrnaseq.var_names_make_unique()
+sc.pp.subsample(scrnaseq, 0.1)
+scrnaseq.var_names_make_unique()
 
 
 # Calculate gene-celltype embeddings from reference data
diff --git a/scripts/1_train_5k.py b/scripts/1_train_5k.py
index f621a3be..751c5a69 100644
--- a/scripts/1_train_5k.py
+++ b/scripts/1_train_5k.py
@@ -17,8 +17,8 @@
 
 
-segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_cells")
-models_dir = Path("./models/human_CRC_seg_cells")
+segger_data_dir = Path("data_tidy/pyg_datasets/MNG_5k_sampled/output-XETG00078__0041719__Region_2__20241203__142052/")
+models_dir = Path("./models/MNG_5k_sampled/output-XETG00078__0041719__Region_2__20241203__142052/")
 
 # Base directory to store PyTorch Lightning models
 # models_dir = Path('models')
@@ -26,8 +26,8 @@
 # Initialize the Lightning data module
 dm = SeggerDataModule(
     data_dir=segger_data_dir,
-    batch_size=2,
-    num_workers=2,
+    batch_size=3,
+    num_workers=3,
 )
 
 dm.setup()
@@ -43,6 +43,7 @@
 
 
 model = Segger(
+    # is_token_based=is_token_based,
     num_tx_tokens= num_tx_tokens,
     init_emb=8,
     hidden_channels=32,
     out_channels=16,
     heads=4,
@@ -64,7 +65,7 @@
     strategy="auto",
     precision="32",
     devices=4,  # set a higher number if more GPUs are available
-    max_epochs=150,
+    max_epochs=250,
     default_root_dir=models_dir,
     logger=CSVLogger(models_dir),
 )
diff --git a/scripts/4_xenium_explorer.py b/scripts/4_xenium_explorer.py
new file mode 100644
index 00000000..5429c1fd
--- /dev/null
+++ b/scripts/4_xenium_explorer.py
@@ -0,0 +1,41 @@
+
+
+from segger.validation.xenium_explorer import seg2explorer
+import pandas as pd
+
+
+transcripts_file = 'data_tidy/benchmarks/human_CRC_seg_nuclei/human_CRC_seg_nuclei_0.4_False_4_15_5_3_20250521/segger_transcripts.parquet'
+transcripts_df = pd.read_parquet(transcripts_file)
+# transcripts_df = transcripts_df.iloc[:10000]
+seg2explorer(
+    seg_df=transcripts_df,  # this is your segger output
+    source_path="/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real",  # raw Xenium data
+    output_dir="data_tidy/explorer/human_CRC_real_nuclei",  # where to save the Xenium Explorer files; can be the same as the raw data dir
+    cells_filename="seg_cells1",  # file name for cells.zarr
+    analysis_filename="seg_analysis1",  # file name for analysis.zarr
+    xenium_filename="seg_experiment1.xenium",  # Xenium Explorer experiment file
+    analysis_df=None,
+    cell_id_columns="segger_cell_id",  # segger cell id column in transcripts_df
+)
+
+
+
+XENIUM_DATA_DIR = Path(  # raw data dir
+    "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real"
+)
+transcripts_file = (
+    XENIUM_DATA_DIR / "transcripts.parquet"
+)
+
+SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei")  # preprocessed data dir
+
+
+seg_tag = "human_CRC_seg_nuclei"
+model_version = 0
+models_dir = Path("./models") / seg_tag  # trained model dir
+
+
+output_dir = Path(  # output dir
+    "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_nuclei"
+)
+
diff --git a/scripts/batch_run_xenium.zip b/scripts/batch_run_xenium.zip
new file mode 100644
index 0000000000000000000000000000000000000000..c996366e063fb700c648ed24d213d9d012453bc4
GIT binary patch
literal 8921
(binary data omitted)
zx$fIn^zy~~{ycv?*XDEP9RAts{PtPvyVqVzSso3Y1n}2@uIl{j;a|U(0Q7)+#x5o> zLnl`|Lk}}MOIKSy4RsIz*vf4O8#=p?cUKQw01EmZ3IOoeG|GQXhXud_0M1R9b2cN@ z*>r}6ijbRk)wK9{4S${ar-@g2{x-3R5vhrV(LJNjiVD!r1kg}rf7BQXG&Vvi6h{Wa zz=H_NP}C>}28c4N%xgUQrQ2#RLkYyj&iYg^U3$g_sT1o8JXUDwv|Ei zZClv`9iS`^#3Bw;ScjgyAo3-DygvQ8*Uvmt!|=P+aBwm+wKQ@0ml$4W*k2+}&Q|67 z^~>r!gihH_-6r%|)>N=WtDrL-t~3+o#8a=vcm*@1B+`E&6Zk5QT4g<^0ndF^M`mY} zvQ%bm3@{of<2aC7mqxj^Le57rj9X_QqQp;gqkkix;BH%_B3t~srGqKgmiJzq#Ibwo z)O}*Sp7kSA@c1mhMZ=0lWj+eZ>xr>)@hlN_^LTUF7>}{ND!P@!Wds7XH=2fgl19n% z<>DU)_uUGg)W6n$sdrE^(%B&2g%o3}P%1MRUs_KpG}EAhU79-QfnSieWNu|-4}~hj zC6CyI4}CBTF!jS=H8%+?Ohww|gb+|)sK@GJPX}gUEMuA4ib3Vp`_PwJcr&Ebkp8Yf zl}rKIa)8<^Y&O}p?V#XpE^~&?ikL8Q*VL^Kz6TXNy3x~Vo=M>XGwv~d4~68$7qs~) zF4fDgxL&>P2MUbcG<4&m*gi3+xx-?fr0~SA>;9I1Md{rE-hQ9X{<7oQML1P~Hocai z%j-T-5RpJtZipV-(c&C+Kg183jiT4 zq*h=10nOO9sULCEZnA~ER|IiJU0M}lk<{Smd)`QdKKdjsLOCAnkxX>+LEVHM_Fbj4 z38Uj3N0EkW*TiM7Qg7mtG~ZM;!cWt5h}Jc$efrdmvztGM&`|xEaVd@@Esxc{(k{FY zBo+6mW^(B1oJ86596JwiC2&Y8pkYZMcRajDwM6@nB<89!&!le^hINXE^7RO7`CH93 zFUr3pD?>$dU3uL+Zt4f`Il7|JVDO+N!WW0$Y9M++vhh_m7P(>vuahV40#7q2KRu** z&dk@AXy)cT>sK(Qp<}fbE-XLdQ#l|bQAMV?5j8_Hojr1+?XLC6ZKtAvhBsxB;@1U& z@4JnUL7K1N5;+(caWl4q<1aZ8^YnjN1lKaIDa%EwzhHWJGsc#MSfFjqj@GD{UR!7b zq~An>KTp!Kxb62cS34`g(qpV)lhA4(0-7E_VW{`z7&+j%MI)uymKTK+n*xqGA%E+K z&dWR{_+ZwUfk}ftLT2^bgzG?9qV#L+yMz8qQ&Cw4(L|kgK2A?R=rNlan~mG2qfV=R zQN|dTpgsgyO>^T&+3Ga9>(!J5U9q(`j&~D3nu@TFz9T$uBU}QK#-A;`UU|KF)gc%@EtS&011-cBL)c{ z?(>BFm5*o0ypT5k0A$+S{jq^O!2|$UeuqpGCo^LgGec8j7vrC?sn1{s`+oh3P4zkZ zDKMe$xS^vO9=E4MLcSuM4*HX2tJ}ck*EYbFL05E^bQUmCL7X}I%hARG&2fUMDhp+e83|3RdQla(PJ7x2PG^iW2BlO+I?MICt3R|UNsEKvv`0K27Pxq^7zdfh% zoSq(0h|ofL$faS~lxpZ8>>e>E0!83;A?op`f?L0wAb!;Jjn_@>tV>Gayq5Kh36`;kSFU{A0X5X7#vbJM5|jtRuHtmj zj6S>-KzN8A5;U)9*+zgrMu{`W4krnsyt2}M)O%)tmAdJYSL7_7JCxTiYphF_g_+EFfTbU=bqx^-loaBN=M_!7-%xpxs;4)ASo}J_ zQeMeFnQU_a#0dm-`0d-g78ek|{sqW0Rm|FK=xS~`iAO)sgCg{0PK;9)eQU*7WMP}n z5WEF16hJ*RHn0kXrw-b^#8DRO7A@q~y5+|9XQ^BGu zH&x~rEUNonR{8~01P`@46N9fXh=g{hmq%F+zAK9%@ijS?ko_PmDvKMcXzdx0VO>06 zn9=WhLbQp_E;9A?8>TPj>Y6HDfVj60j$6mV2+ zVto0`Ybko=sq6>3eV_GpDc^VWY_w&XjHb~mz2h@fd;+t_nk9H^M6tCJ)bsHvyW7(~ z9_@uvH{;Dmbw_mExb?SviMpHoD~}#qm=o zxJU-Q!H(yIoG}O)nLikGK4y^&;{OAK{usXh01n+%{=p&FGsSejbUnd8U?Ccbi;}ymQ|KeOFP>j`bdF95Q#(4@ezio zY?^e#8z6u5W|Q#~r7_}49o})zmjzEdgnjBqt_3 zb|gUxK;D1x>_p^|M1;_1U(;bUM8u!V%pHz`y3)3uTY|EDUoV{O7G!YY&4LaiX7Sgjk*XDhc zd&zL25=x;}A{hb|VhGX;+D&#cr{p0ibAtJlHZ7r(S{O9HH?(@kSRvK5$HtS?T+~{< z0n_q-K^(``sSkUj8>3J*__h_+!N7vSNJ&QKKkGVj>@3;ZZKXP`oDYadMfaz^o_`yCQ|? 
zp2_Y5SZxvd#?|H_QUw`IPf9wyeW6UCY$B`g^8sd~TS;i7#~2)q97!33MKT2%sGT1P zHntJ)K}?k@><10FH|(K{U~#L!jB9GcPl&Whhgg!~rvp1nglF|h?~&W$S9aU7H-7WeWmny!MK{`| zys6PJAOoB>ovMQpbg3y)T&uudF~ohI3cW^yq;eHsT;7&hJ-t5hp@~mj2}RV z-j!8af+?mWM)?Q(O5vuNr|eCg^F!eTDc)ML)tn94dhe(JNAz7u7$&lDkV;-sr;;M& zHBK_ey>3!_L3`m_&Q*da;lr6p(49M)uRoO^QM#ZAK=*ej_MMF0;`?KJ^{!Uk;--T7 zzbZF~@WK3S>`0&$7vmgx!lZ>@A&BZq!N#uTuQq5Ri0yk;Nl@!CRd~EMkYSrYKh+UI z1##!2b#k9>T=As}mYGicBv@jhEOc#l@%GbAiiK-r1hV7W_-?cWV`=l5Sz`sjm$a?b zJYd2pGmBWsbSCrg_l)?-7FJBS*SWk62ARVo%BOcE%kpwP;7JRPZy*Yj}8DJiS>ib z=UIj%^M7~t`2}nL05-dv%Q;#w{)3$3-xBoCa*prUFJul<@3EWWCv05j!rCRIRQ@1h z8ED|auCtSXilVx6j{=>ilw&OWS)BU4{q+Xs-Oxa?-L}1rUU8A>yy2#bbO$cD9gQi_ z`pRmZy-pPyXFxQJwKX8Yh|9iiO`WaCRH#kNVlp%bV&K_jD>`ijdp0mfHWx$1gYop*|B>zkvzu+7Tk76!xy=^bRbk1r=+ zKw73yzEP&UII_PIEMjA;K;$QNIYiN{%pSCzOI&MJ`FxhMcsIDJ`%$-A$3mB2*3i-3 zSVZJI3pvPKE;m#CGOc}o!JMRU7Xp;rL;>%R$1WvZ>eSx4Ntdh{<>~w&>QfzBz}1#7 zWXs@=1GW1v+XR@SdiKmU)5^6z)znEWHgsQY4EbdKIB&aGk=%Lp8mOsU)m|hlJzy3Q z-DQ_R_*P-&bLmZ5-3swH(`2EI*JeB&C9E!a*2|1-1R&W=3HsSy@t zcwv_*9(Sc_eKt8m8NO7fCt7+$Hn!a_PJ?LRDW{9DKaE<2d{Sqg@{{}SwJP(}Oq zV8W&&&GP;>!FTLk)ph8D3^TR?sIS1!_6MmTVN4NK(wxkT9a=jr?gtFomHP0S>u-m9 z9cEJ~`(uP9rd@B(ElSM>Ovk#}bFoWe*zFiRt?S}nnalI!GD}q>QJ#+d-BEW6Ze`;5cPi6Y?>1?4k{ci;~qDvF{dnDrQqz+XPxt|NJVv*3`UnjdbYqBsS>AQMRa7EY2EO zR{_QMNV=Hu(y<`sTkm`s7&5*4Y zc?8F|<702@ccguCgZ@K0QQR3=C5Q&HK@8f!H~Y<8~=F{!G(v(`#uiKVB_ zii>Q)eveeOm-c|RlJ@=s@Z)G-(%52bs*bvCS_i5gmE-X+nZx`YvKOlySe67NzH2BF z?r&p~pT$XUJ9cOxdgyhn<^~S^ZtU_wYzp;mp%*3L=}Uo=t47VJc_V2ciBfeEnquK>=*}yiPcfUwrSvcI6M*TLR z!r$dY1{}P@;{MJet0Z#jxEWi0<8TChFC_n#rc#;+`Kk;p;aN6=J7-IbxJiuJbF)aH z7k?pOb~AiSg}nX%OUX7B6O@cfxzw1XELoE;bCo|`L2FkUQLPxVL%_;K^j2w&=#ie7 z1>u&cQeKBBjgSF|Xkh+XOOumKw>cX^oQ}3dajJXgK0`{JNAMe05^9n??Wvm}cz-IogcgF*V`}xH(c2SMVZ6*O2u_ zZLy@_lg}h&0CQ`Z;W`z%=;Z{PAbtF10{c6bomL->k|)IDCsYKUD%4zAo`_aUqH|~W z^|(=8wV4Ij_|%(hKqHY)=AnXi$XZ)6xOGCQzv%2!m{6W|;-52_rtAAb$wu#vIl=2; znj(&kU21s;l57l93XPmwA;;BPj427*u9~&!oDKD}(Mk(t2(kkUTIfqSipU4QEk8TF z_d%D^l~TWT^BaiQ)21E384b63m5e6fZqeqb_9;RtL~=8i6gj#C1)q-0&~T&yGaX`& z$&pBnf61s$be20IQS$?-=OKhd>VIFe`GrUShhRlPB|-cDw;#{M{WU&8khi~;#DD4; z{t+yKZ~nelH1|5!al zHUKUVi!^y&EYkD8kKBM=?2E ad.AnnData: + """ + Generates an AnnData object from a dataframe of segmented transcriptomics data. + + Parameters: + df (pd.DataFrame): The dataframe containing segmented transcriptomics data. + panel_df (Optional[pd.DataFrame]): The dataframe containing panel information. + min_transcripts (int): The minimum number of transcripts required for a cell to be included. + cell_id_col (str): The column name representing the cell ID in the input dataframe. + qv_threshold (float): The quality value threshold for filtering transcripts. + min_cell_area (float): The minimum cell area to include a cell. + max_cell_area (float): The maximum cell area to include a cell. + + Returns: + ad.AnnData: The generated AnnData object containing the transcriptomics data and metadata. 
+ """ + # Filter out unassigned cells + df_filtered = df[df[cell_id_col].astype(str) != "UNASSIGNED"] + # Create pivot table for gene expression counts per cell + pivot_df = df_filtered.rename( + columns={cell_id_col: "cell", "feature_name": "gene"} + )[["cell", "gene"]].pivot_table( + index="cell", columns="gene", aggfunc="size", fill_value=0 + ) + pivot_df = pivot_df[pivot_df.sum(axis=1) >= min_transcripts] + # Summarize cell metrics + cell_summary = [] + for cell_id, cell_data in df_filtered.groupby(cell_id_col): + if len(cell_data) < min_transcripts: + continue + cell_convex_hull = ConvexHull( + cell_data[["x_location", "y_location"]], qhull_options="QJ" + ) + cell_area = cell_convex_hull.area + if cell_area < min_cell_area or cell_area > max_cell_area: + continue + cell_summary.append( + { + "cell": cell_id, + "cell_centroid_x": cell_data["x_location"].mean(), + "cell_centroid_y": cell_data["y_location"].mean(), + "cell_area": cell_area, + } + ) + cell_summary = pd.DataFrame(cell_summary).set_index("cell") + # Add genes from panel_df (if provided) to the pivot table + if panel_df is not None: + panel_df = panel_df.sort_values("gene") + genes = panel_df["gene"].values + for gene in genes: + if gene not in pivot_df: + pivot_df[gene] = 0 + pivot_df = pivot_df[genes.tolist()] + # Create var DataFrame + if panel_df is None: + var_df = pd.DataFrame( + [ + {"gene": gene, "feature_types": "Gene Expression", "genome": "Unknown"} + for gene in np.unique(pivot_df.columns.values) + ] + ).set_index("gene") + else: + var_df = panel_df[["gene", "ensembl"]].rename(columns={"ensembl": "gene_ids"}) + var_df["feature_types"] = "Gene Expression" + var_df["genome"] = "Unknown" + var_df = var_df.set_index("gene") + # Compute total assigned and unassigned transcript counts for each gene + assigned_counts = df_filtered.groupby("feature_name")["feature_name"].count() + unassigned_counts = ( + df[df[cell_id_col].astype(str) == "UNASSIGNED"] + .groupby("feature_name")["feature_name"] + .count() + ) + # var_df["total_assigned"] = var_df.index.map(assigned_counts).fillna(0).astype(int) + # var_df["total_unassigned"] = ( + # var_df.index.map(unassigned_counts).fillna(0).astype(int) + # ) + # Filter cells and create the AnnData object + cells = list(set(pivot_df.index) & set(cell_summary.index)) + pivot_df = pivot_df.loc[cells, :] + cell_summary = cell_summary.loc[cells, :] + adata = ad.AnnData(pivot_df.values) + adata.var = var_df + adata.obs["transcripts"] = pivot_df.sum(axis=1).values + adata.obs["unique_transcripts"] = (pivot_df > 0).sum(axis=1).values + adata.obs_names = pivot_df.index.values.tolist() + adata.obs = pd.merge( + adata.obs, + cell_summary.loc[adata.obs_names, :], + left_index=True, + right_index=True, + ) + return adata + +transcripts_df_filtered = pd.read_parquet(transcripts_save_path) +anndata_save_path = save_dir / "segger_adata.h5ad" +segger_adata = create_anndata( + transcripts_df_filtered, min_transcripts=5, cell_id_col='segger_cell_id' #**anndata_kwargs +) # Compute for AnnData +segger_adata.write(anndata_save_path) +if verbose: + elapsed_time = time() - step_start_time + print(f"Saved anndata object in {elapsed_time:.2f} seconds.") + +if verbose: +elapsed_time = time() - step_start_time +print(f"Results saved in {elapsed_time:.2f} seconds at {save_dir}.") + +# Step 6: Save segmentation parameters as a JSON log +log_data = { + "seg_tag": seg_tag, + "score_cut": score_cut, + "use_cc": use_cc, + "receptive_field": receptive_field, + "knn_method": knn_method, + "save_transcripts": 
save_transcripts, + "save_anndata": save_anndata, + "save_cell_masks": save_cell_masks, + "timestamp": datetime.now().isoformat(), +} + +log_path = save_dir / "segmentation_log.json" +with open(log_path, "w") as log_file: + json.dump(log_data, log_file, indent=4) + +# Step 7: Garbage collection and memory cleanup +torch.cuda.empty_cache() +gc.collect() + +# Total time taken for the segmentation process +if verbose: + total_time = time() - start_time + print(f"Total segmentation process completed in {total_time:.2f} seconds.") diff --git a/scripts/predict_5k_yiheng.py b/scripts/predict_5k_yiheng.py new file mode 100644 index 00000000..ff0882dc --- /dev/null +++ b/scripts/predict_5k_yiheng.py @@ -0,0 +1,73 @@ +from segger.training.segger_data_module import SeggerDataModule +from segger.prediction.predict_parquet import segment, load_model +from pathlib import Path +from matplotlib import pyplot as plt +import seaborn as sns +import scanpy as sc +import os +import dask.dataframe as dd +import pandas as pd +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["CUPY_CACHE_DIR"] = "./.cupy" +import cupy as cp +from dask.distributed import Client, LocalCluster +from dask_cuda import LocalCUDACluster +import dask.dataframe as dd + + + +seg_tag = "output-XETG00078__0041722__Region_1__20241203__142052" +model_version = 4 +models_dir = Path("./models/MNG_5k_sampled/") / seg_tag + + +seg = "output-XETG00078__0041719__Region_2__20241203__142052" + +XENIUM_DATA_DIR = Path( + "/omics/odcf/analysis/OE0606_projects_temp/xenium_projects/20241209_Xenium5k_CNSL_BrM/20241209_Xenium5k_CNSL_BrM" +) / seg +SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/MNG_5k_sampled") / seg + + +benchmarks_dir = Path( + "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks" +) / seg +transcripts_file = ( + XENIUM_DATA_DIR / "transcripts.parquet" +) +# Initialize the Lightning data module +dm = SeggerDataModule( + data_dir=SEGGER_DATA_DIR, + batch_size=2, + num_workers=2, +) + +dm.setup() + + +# Load in latest checkpoint +model_path = models_dir / "lightning_logs" / f"version_{model_version}" +model = load_model(model_path / "checkpoints") + +receptive_field = {"k_bd": 4, "dist_bd": 7.5, "k_tx": 5, "dist_tx": 3} + +segment( + model, + dm, + save_dir=benchmarks_dir, + seg_tag=seg, + transcript_file=transcripts_file, + # file_format='anndata', + receptive_field=receptive_field, + min_transcripts=5, + score_cut=0.75, + # max_transcripts=1500, + cell_id_col="segger_cell_id", + use_cc=False, + knn_method="kd_tree", + verbose=True, + gpu_ids=["0"], + # client=client +) diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index 249bb439..ec140e06 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -18,7 +18,7 @@ -seg_tag = "human_CRC_seg_cells" +seg_tag = "human_CRC_seg_nuclei" model_version = 0 @@ -26,11 +26,11 @@ XENIUM_DATA_DIR = Path( "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" ) -SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_cells") +SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei") models_dir = Path("./models") / seg_tag benchmarks_dir = Path( - "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_cells" + "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_nuclei" ) transcripts_file = ( 
"/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real/transcripts.parquet" diff --git a/scripts/predict_project24.py b/scripts/predict_project24.py new file mode 100644 index 00000000..67ef7042 --- /dev/null +++ b/scripts/predict_project24.py @@ -0,0 +1,73 @@ +from segger.training.segger_data_module import SeggerDataModule +from segger.prediction.predict_parquet import segment, load_model +from pathlib import Path +from matplotlib import pyplot as plt +import seaborn as sns +import scanpy as sc +import os +import dask.dataframe as dd +import pandas as pd +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["CUPY_CACHE_DIR"] = "./.cupy" +import cupy as cp +from dask.distributed import Client, LocalCluster +from dask_cuda import LocalCUDACluster +import dask.dataframe as dd + + + +seg_tag = "output-XETG00423__0052506__mng_04_TMA__20250310__160549" +model_version = 0 +models_dir = Path("./models/project24_MNG_pqdm") / seg_tag + + +for s in os.listdir('data_tidy/pyg_datasets/project24_MNG_final/'): + if s == seg_tag: + + XENIUM_DATA_DIR = Path( + "data_tidy/pyg_datasets/project24_MNG_final") / s + + + + benchmarks_dir = Path( + "data_tidy/benchmarks/project24_MNG_final") / s + + transcripts_file = Path( + "/omics/odcf/analysis/OE0606_projects/oncolgy_data_exchange/domenico_temp/xenium/xenium_output_files") / s / "transcripts.parquet" + + # Initialize the Lightning data module + dm = SeggerDataModule( + data_dir=XENIUM_DATA_DIR, + batch_size=1, + num_workers=1, + ) + + dm.setup() + + + # Load in latest checkpoint + model_path = models_dir / "lightning_logs" / f"version_{model_version}" + model = load_model(model_path / "checkpoints") + + receptive_field = {"k_bd": 4, "dist_bd": 7.5, "k_tx": 5, "dist_tx": 3} + + segment( + model, + dm, + save_dir=benchmarks_dir, + seg_tag=s, + transcript_file=transcripts_file, + # file_format='anndata', + receptive_field=receptive_field, + min_transcripts=5, + score_cut=0.75, + # max_transcripts=1500, + cell_id_col="segger_cell_id", + use_cc=False, + knn_method="kd_tree", + verbose=True, + gpu_ids=["0"], + # client=client + ) diff --git a/scripts/train_MNG_5k.sh b/scripts/train_MNG_5k.sh index a555e55a..6c828c11 100644 --- a/scripts/train_MNG_5k.sh +++ b/scripts/train_MNG_5k.sh @@ -1,5 +1,9 @@ -DATA_ROOT="data_tidy/pyg_datasets/MNG_5k_sampled" +DATA_ROOT="data_tidy/pyg_datasets" +MODELS_ROOT="models" +PROJECT_NAME="MNG_5k" + + for folder in "$DATA_ROOT"/*; do if [ -d "$folder" ]; then @@ -8,6 +12,6 @@ for folder in "$DATA_ROOT"/*; do -gpu num=4:j_exclusive=yes:gmem=20.7G \ -R "rusage[mem=100GB]" \ -q gpu-debian \ - python /dkfz/cluster/gpu/data/OE0606/elihei/segger_dev/scripts/train_model.py --data_dir "$folder" + python /dkfz/cluster/gpu/data/OE0606/elihei/segger_dev/scripts/train_model.py --data_dir "$DATA_ROOT/$PROJECT_NAME/$folder" --model_dir "$MODELS_ROOT/$PROJECT_NAME/$folder" fi done \ No newline at end of file diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index f64b39dc..063f8f93 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -17,8 +17,8 @@ -segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_cells") -models_dir = Path("./models/human_CRC_seg_cells") +segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_exmax_weights") +models_dir = Path("./models/human_CRC_seg_exmax_weights") # Base directory to store Pytorch Lightning models # models_dir = Path('models') @@ -35,28 +35,28 @@ # 
is_token_based = True # num_tx_tokens = 500 -# If you use custom gene embeddings, use the following two lines instead: -is_token_based = False -num_tx_tokens = ( - dm.train[0].x_dict["tx"].shape[1] -) # Set the number of tokens to the number of genes +# # If you use custom gene embeddings, use the following two lines instead: +# is_token_based = False +# num_tx_tokens = ( +# dm.train[0].x_dict["tx"].shape[1] +# ) # Set the number of tokens to the number of genes model = Segger( # is_token_based=is_token_based, - num_tx_tokens= num_tx_tokens, + num_tx_tokens= 500, init_emb=8, hidden_channels=64, out_channels=16, heads=4, - num_mid_layers=3, + num_mid_layers=2, ) model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) # Wrap the model in LitSegger -ls = LitSegger(model=model) +ls = LitSegger(model=model, align_loss=True, align_lambda=1) # # Initialize the Lightning model # ls = LitSegger( @@ -73,9 +73,9 @@ trainer = Trainer( accelerator="gpu", strategy="auto", - precision="16-mixed", - devices=4, # set higher number if more gpus are available - max_epochs=150, + precision="32", + devices=1, # set higher number if more gpus are available + max_epochs=500, default_root_dir=models_dir, logger=CSVLogger(models_dir), ) diff --git a/src/segger/data/parquet/_utils.py b/src/segger/data/parquet/_utils.py index a3e6de70..51cfc49a 100644 --- a/src/segger/data/parquet/_utils.py +++ b/src/segger/data/parquet/_utils.py @@ -5,13 +5,16 @@ from pyarrow import parquet as pq import numpy as np import scipy as sp -from typing import Optional, List +from typing import Optional, List, Dict, Tuple import sys from types import SimpleNamespace from pathlib import Path import yaml import os import pyarrow as pa +import scanpy as sc +import anndata as ad +from itertools import combinations def get_xy_extents( @@ -509,3 +512,134 @@ def ensure_transcript_ids( write_statistics=True, # Ensure statistics are written compression="snappy", # Use snappy compression for better performance ) + + + +def find_markers( + adata: ad.AnnData, + cell_type_column: str, + pos_percentile: float = 5, + neg_percentile: float = 10, + percentage: float = 50, +) -> Dict[str, Dict[str, List[str]]]: + """ + Identify positive and negative marker genes for each cell type in an AnnData object. + + Positive markers are top-ranked genes that are expressed in at least + `percentage` percent of cells in the given cell type. + + Parameters + ---------- + adata : AnnData + Annotated data object containing gene expression data and cell type annotations. + cell_type_column : str + Name of the column in `adata.obs` specifying cell type identity for each cell. + pos_percentile : float, optional (default: 5) + Percentile threshold for selecting top highly expressed genes as positive markers. + neg_percentile : float, optional (default: 10) + Percentile threshold for selecting lowest expressed genes as negative markers. + percentage : float, optional (default: 50) + Minimum percent of cells (0-100) in a cell type expressing a gene for it to be a marker. 
+ + Returns + ------- + markers : dict + Dictionary mapping cell type names to: + { + 'positive': [list of positive marker gene names], + 'negative': [list of negative marker gene names] + } + """ + markers = {} + sc.tl.rank_genes_groups(adata, groupby=cell_type_column) + genes = np.array(adata.var_names) + n_genes = adata.shape[1] + + # Work with a dense matrix for expression fraction calculation + # (convert sparse to dense if needed) + if not isinstance(adata.X, np.ndarray): + expr_matrix = adata.X.toarray() + else: + expr_matrix = adata.X + + for cell_type in adata.obs[cell_type_column].unique(): + mask = (adata.obs[cell_type_column] == cell_type).values + gene_names = np.array(adata.uns['rank_genes_groups']['names'][cell_type]) + + n_pos = max(1, int(n_genes * pos_percentile // 100)) + n_neg = max(1, int(n_genes * neg_percentile // 100)) + + # Calculate percent of cells in this cell type expressing each gene + expr_frac = (expr_matrix[mask] > 0).mean(axis=0) * 100 # as percent + + # Filter positive markers by expression fraction + pos_indices = [] + for idx in range(n_pos): + gene = gene_names[idx] + gene_idx = np.where(genes == gene)[0][0] + if expr_frac[gene_idx] >= percentage: + pos_indices.append(idx) + positive_markers = list(gene_names[pos_indices]) + + # Negative markers are the lowest-ranked + negative_markers = list(gene_names[-n_neg:]) + + markers[cell_type] = { + "positive": positive_markers, + "negative": negative_markers + } + return markers + +def find_mutually_exclusive_genes( + adata: ad.AnnData, markers: Dict[str, Dict[str, List[str]]], cell_type_column: str +) -> List[Tuple[str, str]]: + """Identify mutually exclusive genes based on expression criteria. + + Args: + - adata: AnnData + Annotated data object containing gene expression data. + - markers: dict + Dictionary where keys are cell types and values are dictionaries containing: + 'positive': list of top x% highly expressed genes + 'negative': list of top x% lowly expressed genes. + - cell_type_column: str + Column name in `adata.obs` that specifies cell types. + + Returns: + - exclusive_pairs: list + List of mutually exclusive gene pairs. 
+ """ + exclusive_genes = {} + all_exclusive = [] + gene_expression = adata.to_df() + for cell_type, marker_sets in markers.items(): + positive_markers = marker_sets["positive"] + exclusive_genes[cell_type] = [] + for gene in positive_markers: + gene_expr = adata[:, gene].X + cell_type_mask = adata.obs[cell_type_column] == cell_type + non_cell_type_mask = ~cell_type_mask + if (gene_expr[cell_type_mask] > 0).mean() > 0.2 and ( + gene_expr[non_cell_type_mask] > 0 + ).mean() < 0.05: + exclusive_genes[cell_type].append(gene) + all_exclusive.append(gene) + unique_genes = list( + { + gene + for i in exclusive_genes.keys() + for gene in exclusive_genes[i] + if gene in all_exclusive + } + ) + filtered_exclusive_genes = { + i: [gene for gene in exclusive_genes[i] if gene in unique_genes] + for i in exclusive_genes.keys() + } + mutually_exclusive_gene_pairs = [ + tuple(sorted((gene1, gene2))) + for key1, key2 in combinations(filtered_exclusive_genes.keys(), 2) + for gene1 in filtered_exclusive_genes[key1] + for gene2 in filtered_exclusive_genes[key2] + ] + return set(mutually_exclusive_gene_pairs) \ No newline at end of file diff --git a/src/segger/data/parquet/sample.py b/src/segger/data/parquet/sample.py index 0729b9f4..0e5e4fd2 100644 --- a/src/segger/data/parquet/sample.py +++ b/src/segger/data/parquet/sample.py @@ -366,6 +366,8 @@ def save( dist_bd: float = 15.0, k_tx: int = 3, dist_tx: float = 5.0, + k_tx_ex: int = 100, + dist_tx_ex: float = 20, tile_size: Optional[int] = None, tile_width: Optional[float] = None, tile_height: Optional[float] = None, @@ -373,6 +375,7 @@ def save( frac: float = 1.0, val_prob: float = 0.1, test_prob: float = 0.2, + mutually_exclusive_genes: Optional[List] = None, ): """ Saves the tiles of the sample as PyTorch geometric datasets. See @@ -455,7 +458,10 @@ def func(region): dist_bd=dist_bd, k_tx=k_tx, dist_tx=dist_tx, + k_tx_ex=k_tx_ex, + dist_tx_ex=dist_tx_ex, neg_sampling_ratio=neg_sampling_ratio, + mutually_exclusive_genes = mutually_exclusive_genes ) if pyg_data is not None: if pyg_data["tx", "belongs", "bd"].edge_index.numel() == 0: @@ -1179,15 +1185,18 @@ def get_boundary_props( def to_pyg_dataset( self, # train: bool, - neg_sampling_ratio: float = 5, + neg_sampling_ratio: float = 10, k_bd: int = 3, dist_bd: float = 15, k_tx: int = 3, dist_tx: float = 5, + k_tx_ex: int = 100, + dist_tx_ex: float = 20, area: bool = True, convexity: bool = True, elongation: bool = True, circularity: bool = True, + mutually_exclusive_genes: Optional[List] = None, ) -> HeteroData: """ Converts the sample data to a PyG HeteroData object. 
@@ -1214,6 +1223,8 @@ def to_pyg_dataset( ) pyg_data["tx"].x = self.get_transcript_props() + + # Set up Transcript-Transcript neighbor edges nbrs_edge_idx = self.get_kdtree_edge_index( self.transcripts[self.settings.transcripts.xyz], @@ -1228,6 +1239,27 @@ def to_pyg_dataset( pyg_data["tx", "neighbors", "tx"].edge_index = nbrs_edge_idx + + if mutually_exclusive_genes is not None: + nbrs_edge_idx = self.get_kdtree_edge_index( + self.transcripts[self.settings.transcripts.xyz], + self.transcripts[self.settings.transcripts.xyz], + k=k_tx_ex, + max_distance=dist_tx_ex, + ) + gene_ids = self.transcripts[self.settings.transcripts.label].tolist() + src_gene = [gene_ids[idx] for idx in nbrs_edge_idx[0].tolist()] + dst_gene = [gene_ids[idx] for idx in nbrs_edge_idx[1].tolist()] + + mask = [ + tuple(sorted((a, b))) in mutually_exclusive_genes + for a, b in zip(src_gene, dst_gene) + ] + mask = torch.tensor(mask) + + pyg_data["tx", "excludes", "tx"].edge_index = nbrs_edge_idx[:, mask] + + # Set up Boundary nodes # Check if boundaries have geometries geometry_column = getattr(self.settings.boundaries, 'geometry', None) @@ -1305,23 +1337,38 @@ def to_pyg_dataset( return pyg_data # If there are tx-bd edges, add negative edges for training - transform = RandomLinkSplit( - num_val=0, - num_test=0, - is_undirected=True, - edge_types=[edge_type], - neg_sampling_ratio=neg_sampling_ratio, - ) - pyg_data, _, _ = transform(pyg_data) + pos_edges = blng_edge_idx # shape (2, num_pos) + num_pos = pos_edges.shape[1] - # Refilter negative edges to include only transcripts in the - # original positive edges (still need a memory-efficient solution) - edges = pyg_data[edge_type] - mask = edges.edge_label_index[0].unsqueeze(1) == edges.edge_index[0].unsqueeze( - 0 - ) + # Negative edges (tx-neighbors-bd) - EXCLUDE positives + neg_candidates = nbrs_edge_idx # shape (2, num_candidates) + + # --- Fast Negative Filtering (PyTorch-only) --- + # Reshape edges for broadcasting: (2, num_pos) vs (2, num_candidates, 1) + pos_expanded = pos_edges.unsqueeze(2) # shape (2, num_pos, 1) + neg_expanded = neg_candidates.unsqueeze(1) # shape (2, 1, num_candidates) + + # Compare all edges in one go (broadcasting) + matches = (pos_expanded == neg_expanded).all(dim=0) # shape (num_pos, num_candidates) + is_negative = ~matches.any(dim=0) # shape (num_candidates,) + + # Filter negatives + neg_edges = neg_candidates[:, is_negative] # shape (2, num_filtered_neg) + num_neg = neg_edges.shape[1] + + # --- Combine and label --- + edge_label_index = torch.cat([neg_edges, pos_edges], dim=1) + edge_label = torch.cat([ + torch.zeros(num_neg, dtype=torch.float), + torch.ones(num_pos, dtype=torch.float) + ]) + + mask = edge_label_index[0].unsqueeze(1) == blng_edge_idx[0].unsqueeze(0) mask = torch.nonzero(torch.any(mask, 1)).squeeze() - edges.edge_label_index = edges.edge_label_index[:, mask] - edges.edge_label = edges.edge_label[mask] + edge_label_index = edge_label_index[:, mask] + edge_label = edge_label[mask] + + pyg_data[edge_type].edge_label_index = edge_label_index + pyg_data[edge_type].edge_label = edge_label return pyg_data diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index 8214ea41..33ec5521 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -583,7 +583,7 @@ def segment( elapsed_time = time() - step_start_time print(f"Batch processing completed in {elapsed_time:.2f} seconds.") - seg_final_dd = pd.read_parquet(output_ddf_save_path) + # 
seg_final_dd = pd.read_parquet(output_ddf_save_path)
 
     step_start_time = time()
     if verbose:
diff --git a/src/segger/training/train.py b/src/segger/training/train.py
index e0c5de03..0ff9d549 100644
--- a/src/segger/training/train.py
+++ b/src/segger/training/train.py
@@ -6,7 +6,6 @@
 from lightning import LightningModule
 from segger.models.segger_model import *
 
-
 class LitSegger(LightningModule):
     """
     LitSegger is a PyTorch Lightning module for training and validating the
@@ -17,6 +16,8 @@ def __init__(
         self,
         model: Segger,
         learning_rate: float = 1e-3,
+        align_loss: bool = False,
+        align_lambda: float = 0
     ):
         """
         Initialize the Segger training module.
@@ -35,9 +36,12 @@ def __init__(
 
         # Other setup
         self.learning_rate = learning_rate
+        self.align_loss = align_loss
+        self.align_lambda = align_lambda
         self.criterion = torch.nn.BCEWithLogitsLoss()
         self.validation_step_outputs = []
 
+
     def forward(self, batch) -> torch.Tensor:
         """
         Forward pass for the batch of data.
@@ -74,18 +78,21 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
         """
         # Forward pass to get the logits
         z = self.model(batch.x_dict, batch.edge_index_dict)
-        output = torch.matmul(z["tx"], z["bd"].t())
-
-        # Get edge labels and logits
-        edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
-        out_values = output[edge_label_index[0], edge_label_index[1]]
+        edge_index = batch["tx", "belongs", "bd"].edge_label_index
+        # Compute edge scores as dot products between tx and bd embeddings
+        out_values = (z["tx"][edge_index[0]] * z["bd"][edge_index[1]]).sum(-1)
         edge_label = batch["tx", "belongs", "bd"].edge_label
-
-        # Compute binary cross-entropy loss with logits (no sigmoid here)
         loss = self.criterion(out_values, edge_label)
-        # Log the training loss
         self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs)
+        if self.align_loss:
+            # Penalize high similarity between mutually exclusive transcript pairs
+            edge_index = batch["tx", "excludes", "tx"].edge_index
+            out_values = (z["tx"][edge_index[0]] * z["tx"][edge_index[1]]).sum(-1)
+            targets = torch.zeros_like(out_values)
+            align_loss = self.criterion(out_values, targets)
+            self.log("align_loss", align_loss, prog_bar=True, batch_size=batch.num_graphs)
+            loss = loss + align_loss * self.align_lambda
         return loss
 
     def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
@@ -106,14 +113,11 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
         """
         # Forward pass to get the logits
         z = self.model(batch.x_dict, batch.edge_index_dict)
-        output = torch.matmul(z["tx"], z["bd"].t())
-
-        # Get edge labels and logits
-        edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
-        out_values = output[edge_label_index[0], edge_label_index[1]]
+        edge_index = batch["tx", "belongs", "bd"].edge_label_index
+        # Compute edge scores as dot products between tx and bd embeddings
+        out_values = (z["tx"][edge_index[0]] * z["bd"][edge_index[1]]).sum(-1)
+        edge_label = batch["tx", "belongs", "bd"].edge_label
-
-        # Compute binary cross-entropy loss with logits (no sigmoid here)
         loss = self.criterion(out_values, edge_label)
 
         # Apply sigmoid to logits for AUROC and F1 metrics
@@ -149,7 +153,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-4)
+        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4)
         return optimizer
 
     def on_before_optimizer_step(self, optimizer):
diff --git a/src/segger/validation/xenium_explorer.py b/src/segger/validation/xenium_explorer.py
index d6489bd7..219cc4d0 100644
--- a/src/segger/validation/xenium_explorer.py
+++ b/src/segger/validation/xenium_explorer.py
@@ -8,6 +8,217 @@
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 from typing import Dict, Any, Optional, List, Tuple
+from segger.prediction.boundary import generate_boundary
+from shapely import Polygon
+import zarr
+
+
+
+def get_flatten_version(polygon_vertices: List[List[Tuple[float, float]]], max_value: int = 16) -> np.ndarray:
+    """Standardize a list of polygon vertices to a fixed shape.
+
+    Args:
+        polygon_vertices (List[List[Tuple[float, float]]]): List of polygon coordinate lists.
+        max_value (int): Max number of coordinates per polygon.
+
+    Returns:
+        np.ndarray: Padded or truncated array of polygon vertices.
+    """
+    flattened = []
+    for vertices in polygon_vertices:
+        if isinstance(vertices, np.ndarray):
+            # Handle numpy array case: truncate or zero-pad to max_value rows
+            if len(vertices) >= max_value:
+                flattened.append(vertices[:max_value])
+            else:
+                padding = np.zeros((max_value - len(vertices), vertices.shape[1]))
+                flattened.append(np.concatenate([vertices, padding]))
+        elif len(vertices) > max_value:
+            flattened.append(vertices[:max_value])
+        else:
+            flattened.append(vertices + [(0.0, 0.0)] * (max_value - len(vertices)))
+    return np.array(flattened, dtype=np.float32)
+
+
+def seg2explorer(
+    seg_df: pd.DataFrame,
+    source_path: str,
+    output_dir: str,
+    cells_filename: str = "seg_cells",
+    analysis_filename: str = "seg_analysis",
+    xenium_filename: str = "seg_experiment.xenium",
+    analysis_df: Optional[pd.DataFrame] = None,
+    draw: bool = False,
+    cell_id_columns: str = "seg_cell_id",
+    area_low: float = 10,
+    area_high: float = 100,
+) -> None:
+    """Convert segmentation results into a Xenium Explorer-compatible Zarr dataset.
+
+    Args:
+        seg_df (pd.DataFrame): Segmented transcript dataframe.
+        source_path (str): Path to the original Zarr store.
+        output_dir (str): Output directory to save new Zarr and Xenium files.
+        cells_filename (str): Filename prefix for cell Zarr file.
+        analysis_filename (str): Filename prefix for cell group Zarr file.
+        xenium_filename (str): Output experiment filename for Xenium.
+        analysis_df (Optional[pd.DataFrame]): Optional dataframe with cluster annotations.
+        draw (bool): Whether to draw polygons (not used currently).
+        cell_id_columns (str): Column containing cell IDs.
+        area_low (float): Minimum area threshold to include cells.
+        area_high (float): Maximum area threshold to include cells.
+ """ + source_path = Path(source_path) + storage = Path(output_dir) + + cell_id2old_id: Dict[int, Any] = {} + cell_id: List[int] = [] + cell_summary: List[Dict[str, Any]] = [] + polygon_num_vertices: List[List[int]] = [[], []] + polygon_vertices: List[List[Any]] = [[], []] + seg_mask_value: List[int] = [] + + grouped_by = seg_df.groupby(cell_id_columns) + + for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm( + enumerate(grouped_by), total=len(grouped_by) + ): + if len(seg_cell) < 5: + continue + + cell_convex_hull = generate_boundary(seg_cell) + if cell_convex_hull is None or not isinstance(cell_convex_hull, Polygon): + continue + + if not (area_low <= cell_convex_hull.area <= area_high): + continue + + uint_cell_id = cell_incremental_id + 1 + cell_id2old_id[uint_cell_id] = seg_cell_id + + seg_nucleous = seg_cell[seg_cell["overlaps_nucleus"] == 1] + nucleus_convex_hull = None + # if len(seg_nucleous) >= 3: + # try: + # nucleus_convex_hull = ConvexHull(seg_nucleous[["x_location", "y_location"]]) + # except Exception: + # pass + + cell_id.append(uint_cell_id) + cell_summary.append( + { + "cell_centroid_x": seg_cell["x_location"].mean(), + "cell_centroid_y": seg_cell["y_location"].mean(), + "cell_area": cell_convex_hull.area, + "nucleus_centroid_x": seg_cell["x_location"].mean(), + "nucleus_centroid_y": seg_cell["y_location"].mean(), + "nucleus_area": cell_convex_hull.area, + "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3, + } + ) + polygon_num_vertices[0].append(len(cell_convex_hull.exterior.coords)) + # polygon_num_vertices[1].append( + # len(nucleus_convex_hull.vertices) if nucleus_convex_hull else 0 + # ) + polygon_vertices[0].append(list(cell_convex_hull.exterior.coords)) + # polygon_vertices[1].append( + # seg_nucleous[["x_location", "y_location"]].values[ + # nucleus_convex_hull.vertices + # ] + # if nucleus_convex_hull else np.array([[], []]).T + # ) + seg_mask_value.append(uint_cell_id) + + print(polygon_vertices[0][0]) + + cell_polygon_vertices = get_flatten_version(polygon_vertices[0], max_value=21) + # nucl_polygon_vertices = get_flatten_version(polygon_vertices[1], max_value=21) + + print(cell_polygon_vertices) + print(cell_polygon_vertices.shape) + # print(nucl_polygon_vertices) + # print(nucl_polygon_vertices.shape) + cells = { + "cell_id": np.array( + [np.array(cell_id), np.ones(len(cell_id))], dtype=np.uint32 + ).T, + "cell_summary": pd.DataFrame(cell_summary).values.astype(np.float64), + "polygon_num_vertices": np.array( + [ + # [min(x + 1, x + 1) for x in polygon_num_vertices[1]], + [min(x + 1, x + 1) for x in polygon_num_vertices[0]], + ], + dtype=np.int32, + ), + "polygon_vertices": np.array( + [cell_polygon_vertices], dtype=np.float32 + ), + "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), + } + + print(cells) + + existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r") + new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") + new_store["cell_id"] = cells["cell_id"] + new_store["polygon_num_vertices"] = cells["polygon_num_vertices"] + new_store["polygon_vertices"] = cells["polygon_vertices"] + new_store["seg_mask_value"] = cells["seg_mask_value"] + new_store.attrs.update(existing_store.attrs) + new_store.attrs["number_cells"] = len(cells["cell_id"]) + new_store.store.close() + + if analysis_df is None: + analysis_df = pd.DataFrame( + [cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns] + ) + analysis_df["default"] = "seg" + + zarr_df = pd.DataFrame( + [cell_id2old_id[i] for i in cell_id], 
columns=[cell_id_columns] + ) + clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_columns) + clusters_names = [col for col in analysis_df.columns if col != cell_id_columns] + + clusters_dict = { + cluster: { + label: idx + 1 + for idx, label in enumerate( + sorted(np.unique(clustering_df[cluster].dropna())) + ) + } + for cluster in clusters_names + } + + new_zarr = zarr.open(storage / f"{analysis_filename}.zarr.zip", mode="w") + new_zarr.create_group("/cell_groups") + for i, cluster in enumerate(clusters_names): + new_zarr["cell_groups"].create_group(str(i)) + group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] + indices, indptr = get_indices_indptr(np.array(group_values)) + new_zarr["cell_groups"][str(i)]["indices"] = indices + new_zarr["cell_groups"][str(i)]["indptr"] = indptr + + new_zarr["cell_groups"].attrs.update( + { + "major_version": 1, + "minor_version": 0, + "number_groupings": len(clusters_names), + "grouping_names": clusters_names, + "group_names": [ + sorted(clusters_dict[cluster], key=clusters_dict[cluster].get) + for cluster in clusters_names + ], + } + ) + new_zarr.store.close() + + generate_experiment_file( + template_path=source_path / "experiment.xenium", + output_path=storage / xenium_filename, + cells_name=cells_filename, + analysis_name=analysis_filename, + ) + def str_to_uint32(cell_id_str: str) -> Tuple[int, int]: @@ -219,220 +430,6 @@ def get_median_expression_table(adata, column: str = "leiden") -> pd.DataFrame: return cluster_expression_df.T.style.background_gradient(cmap="Greens") -def seg2explorer( - seg_df: pd.DataFrame, - source_path: str, - output_dir: str, - cells_filename: str = "seg_cells", - analysis_filename: str = "seg_analysis", - xenium_filename: str = "seg_experiment.xenium", - analysis_df: Optional[pd.DataFrame] = None, - draw: bool = False, - cell_id_columns: str = "seg_cell_id", - area_low: float = 10, - area_high: float = 100, -) -> None: - """Convert seg output to a format compatible with Xenium explorer. - - Args: - seg_df (pd.DataFrame): The seg DataFrame. - source_path (str): The source path. - output_dir (str): The output directory. - cells_filename (str): The filename for cells. - analysis_filename (str): The filename for analysis. - xenium_filename (str): The filename for Xenium. - analysis_df (Optional[pd.DataFrame]): The analysis DataFrame. - draw (bool): Whether to draw the plots. - cell_id_columns (str): The cell ID columns. - area_low (float): The lower area threshold. - area_high (float): The upper area threshold. 
- """ - import zarr - import json - - source_path = Path(source_path) - storage = Path(output_dir) - - cell_id2old_id = {} - cell_id = [] - cell_summary = [] - polygon_num_vertices = [[], []] - polygon_vertices = [[], []] - seg_mask_value = [] - tma_id = [] - - grouped_by = seg_df.groupby(cell_id_columns) - for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm( - enumerate(grouped_by), total=len(grouped_by) - ): - if len(seg_cell) < 5: - continue - - cell_convex_hull = ConvexHull(seg_cell[["x_location", "y_location"]]) - if cell_convex_hull.area > area_high: - continue - if cell_convex_hull.area < area_low: - continue - - uint_cell_id = cell_incremental_id + 1 - cell_id2old_id[uint_cell_id] = seg_cell_id - - seg_nucleous = seg_cell[seg_cell["overlaps_nucleus"] == 1] - if len(seg_nucleous) >= 3: - nucleus_convex_hull = ConvexHull(seg_nucleous[["x_location", "y_location"]]) - - cell_id.append(uint_cell_id) - cell_summary.append( - { - "cell_centroid_x": seg_cell["x_location"].mean(), - "cell_centroid_y": seg_cell["y_location"].mean(), - "cell_area": cell_convex_hull.area, - "nucleus_centroid_x": seg_cell["x_location"].mean(), - "nucleus_centroid_y": seg_cell["y_location"].mean(), - "nucleus_area": cell_convex_hull.area, - "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3, - } - ) - - polygon_num_vertices[0].append(len(cell_convex_hull.vertices)) - polygon_num_vertices[1].append( - len(nucleus_convex_hull.vertices) if len(seg_nucleous) >= 3 else 0 - ) - polygon_vertices[0].append( - seg_cell[["x_location", "y_location"]].values[cell_convex_hull.vertices] - ) - polygon_vertices[1].append( - seg_nucleous[["x_location", "y_location"]].values[ - nucleus_convex_hull.vertices - ] - if len(seg_nucleous) >= 3 - else np.array([[], []]).T - ) - seg_mask_value.append(cell_incremental_id + 1) - - cell_polygon_vertices = get_flatten_version(polygon_vertices[0], max_value=21) - nucl_polygon_vertices = get_flatten_version(polygon_vertices[1], max_value=21) - - cells = { - "cell_id": np.array( - [np.array(cell_id), np.ones(len(cell_id))], dtype=np.uint32 - ).T, - "cell_summary": pd.DataFrame(cell_summary).values.astype(np.float64), - "polygon_num_vertices": np.array( - [ - [min(x + 1, x + 1) for x in polygon_num_vertices[1]], - [min(x + 1, x + 1) for x in polygon_num_vertices[0]], - ], - dtype=np.int32, - ), - "polygon_vertices": np.array( - [nucl_polygon_vertices, cell_polygon_vertices] - ).astype(np.float32), - "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), - } - - existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r") - new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") - - new_store["cell_id"] = cells["cell_id"] - new_store["polygon_num_vertices"] = cells["polygon_num_vertices"] - new_store["polygon_vertices"] = cells["polygon_vertices"] - new_store["seg_mask_value"] = cells["seg_mask_value"] - - new_store.attrs.update(existing_store.attrs) - new_store.attrs["number_cells"] = len(cells["cell_id"]) - new_store.store.close() - - if analysis_df is None: - analysis_df = pd.DataFrame( - [cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns] - ) - analysis_df["default"] = "seg" - - zarr_df = pd.DataFrame( - [cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns] - ) - clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_columns) - - clusters_names = [i for i in analysis_df.columns if i != cell_id_columns] - clusters_dict = { - cluster: { - j: i - for i, j in zip( - range(1, 
len(sorted(np.unique(clustering_df[cluster].dropna()))) + 1), - sorted(np.unique(clustering_df[cluster].dropna())), - ) - } - for cluster in clusters_names - } - - new_zarr = zarr.open(storage / (analysis_filename + ".zarr.zip"), mode="w") - new_zarr.create_group("/cell_groups") - - clusters = [ - [clusters_dict[cluster].get(x, 0) for x in list(clustering_df[cluster])] - for cluster in clusters_names - ] - - for i in range(len(clusters)): - new_zarr["cell_groups"].create_group(i) - indices, indptr = get_indices_indptr(np.array(clusters[i])) - new_zarr["cell_groups"][i].create_dataset("indices", data=indices) - new_zarr["cell_groups"][i].create_dataset("indptr", data=indptr) - - new_zarr["cell_groups"].attrs.update( - { - "major_version": 1, - "minor_version": 0, - "number_groupings": len(clusters_names), - "grouping_names": clusters_names, - "group_names": [ - [ - x[0] - for x in sorted(clusters_dict[cluster].items(), key=lambda x: x[1]) - ] - for cluster in clusters_names - ], - } - ) - - new_zarr.store.close() - generate_experiment_file( - template_path=source_path / "experiment.xenium", - output_path=storage / xenium_filename, - cells_name=cells_filename, - analysis_name=analysis_filename, - ) - - -def get_flatten_version(polygons: List[np.ndarray], max_value: int = 21) -> np.ndarray: - """Get the flattened version of polygon vertices. - - Args: - polygons (List[np.ndarray]): List of polygon vertices. - max_value (int): The maximum number of vertices to keep. - - Returns: - np.ndarray: The flattened array of polygon vertices. - """ - n = max_value + 1 - result = np.zeros((len(polygons), n * 2)) - for i, polygon in tqdm(enumerate(polygons), total=len(polygons)): - num_points = len(polygon) - if num_points == 0: - result[i] = np.zeros(n * 2) - continue - elif num_points < max_value: - repeated_points = np.tile(polygon[0], (n - num_points, 1)) - padded_polygon = np.concatenate((polygon, repeated_points), axis=0) - else: - padded_polygon = np.zeros((n, 2)) - padded_polygon[: min(num_points, n)] = polygon[: min(num_points, n)] - padded_polygon[-1] = polygon[0] - result[i] = padded_polygon.flatten() - return result - - def generate_experiment_file( template_path: str, output_path: str, @@ -452,8 +449,8 @@ def generate_experiment_file( with open(template_path) as f: experiment = json.load(f) - experiment["images"].pop("morphology_filepath") - experiment["images"].pop("morphology_focus_filepath") + # experiment["images"].pop("morphology_filepath") + # experiment["images"].pop("morphology_focus_filepath") experiment["xenium_explorer_files"][ "cells_zarr_filepath" @@ -464,4 +461,4 @@ def generate_experiment_file( ] = f"{analysis_name}.zarr.zip" with open(output_path, "w") as f: - json.dump(experiment, f, indent=2) + json.dump(experiment, f, indent=2) \ No newline at end of file From f73708705e596780c1d83324a0c9ef35afc843ba Mon Sep 17 00:00:00 2001 From: Elihei2 Date: Mon, 18 Aug 2025 16:55:21 +0200 Subject: [PATCH 2/3] added triplet loss for ME genes --- .DS_Store | Bin 10244 -> 10244 bytes scripts/batch_run_xenium/create_data_batch.py | 17 +- scripts/batch_run_xenium/predict_batch.sh | 6 +- scripts/batch_run_xenium/train_batch.py | 17 +- scripts/batch_run_xenium/train_batch_BrM.sh | 18 ++ scripts/batch_run_xenium/train_batch_GB.sh | 18 ++ scripts/create_data_fast_sample.py | 30 ++-- scripts/predict_model_sample.py | 31 ++-- scripts/train_mimmo_batch.py | 2 +- scripts/train_model_sample.py | 30 ++-- src/segger/data/parquet/_utils.py | 162 +++++++++++++++++- src/segger/data/parquet/sample.py | 99 
++++++----- src/segger/models/segger_model.py | 16 +- src/segger/prediction/predict_parquet.py | 1 + src/segger/training/train.py | 39 ++++- 15 files changed, 380 insertions(+), 106 deletions(-) create mode 100644 scripts/batch_run_xenium/train_batch_BrM.sh create mode 100644 scripts/batch_run_xenium/train_batch_GB.sh diff --git a/.DS_Store b/.DS_Store index ec2e1f4516b97f824ad79d48cde8596ceca81289..8c1e4cb3dad1c1e8bf5550e4907712f557dbb7ce 100644 GIT binary patch delta 179 zcmZn(XbG6$FUZHhz`)4BAi$85ZWx@LpIfl8a2or>2Eonj94s95AXyd$J%)6KOokGe z3ZNK+gm~i&WVzgY7nh`*{3M_lhbYt20wvkwj;Qh}c;yQ+AhrYbFi5x_08uOqDVv1^ RYWX*_EBs>F+$qA$3;<>9DBA!4 delta 71 zcmZn(XbG6$FUrNhz`)4BAi$7RUR;orlb;0SZ{L`>n0+%l2MY%`NScKqg(06Id2)e7 Y(`Es|?|c&rDmJq#{9@T$Dayb%7 diff --git a/scripts/batch_run_xenium/create_data_batch.py b/scripts/batch_run_xenium/create_data_batch.py index 4f3326b6..dc14d0eb 100644 --- a/scripts/batch_run_xenium/create_data_batch.py +++ b/scripts/batch_run_xenium/create_data_batch.py @@ -9,6 +9,7 @@ import os from pqdm.processes import pqdm from tqdm import tqdm +from segger.data.parquet._utils import find_markers, find_mutually_exclusive_genes def main(): # Set up argument parser @@ -20,7 +21,7 @@ def main(): parser.add_argument('--celltype_column', type=str, default="Annotation_merged", help='Column name for cell types in scRNA-seq data') parser.add_argument('--n_workers', type=int, default=4, help='Number of workers for processing') parser.add_argument('--k_tx', type=int, default=5, help='Number of neighbors for transcript graph') - parser.add_argument('--dist_tx', type=float, default=5.0, help='Distance threshold for transcript graph') + parser.add_argument('--dist_tx', type=float, default=20.0, help='Distance threshold for transcript graph') parser.add_argument('--subsample_frac', type=float, default=0.1, help='Subsampling fraction for scRNA-seq data') args = parser.parse_args() @@ -51,14 +52,24 @@ def main(): # scale_factor=0.5 # this is to shrink the initial seg. masks (used for seg. 
kit) ) + genes = list(set(scrnaseq.var_names) & set(sample.transcripts_metadata['feature_names'])) + markers = find_markers(scrnaseq[:,genes], cell_type_column=args.celltype_column, pos_percentile=90, neg_percentile=20, percentage=20) + # Find mutually exclusive genes based on scRNAseq data + exclusive_gene_pairs = find_mutually_exclusive_genes( + adata=scrnaseq, + markers=markers, + cell_type_column=args.celltype_column + ) + sample.save( data_dir=segger_data_dir, k_bd=3, dist_bd=15, k_tx=args.k_tx, dist_tx=args.dist_tx, - tile_width=150, - tile_height=150, + k_tx_ex=20, + dist_tx_ex=20, + tile_size=10_000, # Tile size for processing neg_sampling_ratio=5.0, frac=1.0, val_prob=0.3, diff --git a/scripts/batch_run_xenium/predict_batch.sh b/scripts/batch_run_xenium/predict_batch.sh index 6387a2c9..8b2e7079 100644 --- a/scripts/batch_run_xenium/predict_batch.sh +++ b/scripts/batch_run_xenium/predict_batch.sh @@ -8,7 +8,7 @@ OUTPUT_DIR=${4:-"logs"} MODELS_ROOT=${5:-"./models/project24_MNG_pqdm"} BENCHMARKS_ROOT=${6:-"data_tidy/benchmarks/project24_MNG_final"} TRANSCRIPTS_ROOT=${7:-"/omics/odcf/analysis/OE0606_projects/oncolgy_data_exchange/domenico_temp/xenium/xenium_output_files"} -GPU_MEM=${8:-"32.0G"} +GPU_MEM=${8:-"39G"} SYSTEM_MEM=${9:-"200GB"} FORCE=${10:-false} @@ -30,9 +30,10 @@ for SAMPLE in "${ALL_SAMPLES[@]}"; do # Define output file path H5AD_FILE1="${BENCHMARKS_ROOT}/${SAMPLE}/${SAMPLE}_0.75_False_4_7.5_5_3_20250709/segger_adata.h5ad" H5AD_FILE2="${BENCHMARKS_ROOT}/${SAMPLE}/${SAMPLE}_0.75_False_4_7.5_5_3_20250714/segger_adata.h5ad" + H5AD_FILE3="${BENCHMARKS_ROOT}/${SAMPLE}/${SAMPLE}_0.75_False_4_7.5_5_3_20250728/segger_adata.h5ad" # Check if processing should be skipped - if [[ -f "$H5AD_FILE1" || -f "$H5AD_FILE2" ]] && [ "$FORCE" != "true" ]; then + if [[ -f "$H5AD_FILE1" || -f "$H5AD_FILE2" || -f "$H5AD_FILE3" ]] && [ "$FORCE" != "true" ]; then echo "segger_adata.h5ad exists for $SAMPLE, skipping..." 
((skipped++)) continue @@ -48,6 +49,7 @@ for SAMPLE in "${ALL_SAMPLES[@]}"; do bsub -o ${OUTPUT_DIR}/segmentation_${SAMPLE}.log \ -e ${OUTPUT_DIR}/segmentation_${SAMPLE}.err \ -gpu "num=1:j_exclusive=yes:gmem=${GPU_MEM}" \ + -m gpu-a100-40gb \ -R "rusage[mem=${SYSTEM_MEM}]" \ -q gpu-debian \ python ../segger_dev/scripts/batch_run_xenium/predict_batch.py \ diff --git a/scripts/batch_run_xenium/train_batch.py b/scripts/batch_run_xenium/train_batch.py index ae377848..3bcbebb5 100644 --- a/scripts/batch_run_xenium/train_batch.py +++ b/scripts/batch_run_xenium/train_batch.py @@ -28,27 +28,30 @@ model = Segger( num_tx_tokens=num_tx_tokens, - init_emb=8, - hidden_channels=32, + init_emb=16, + hidden_channels=64, out_channels=16, heads=4, num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="mean") batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) +# Wrap the model in LitSegger +ls = LitSegger(model=model, align_loss=False) -ls = LitSegger(model=model) +# Initialize the Lightning trainer trainer = Trainer( accelerator="gpu", strategy="auto", - precision="16-mixed", # for 5k change to "32" - devices=4, - max_epochs=200, # check models/project_tag/sample/lightning_logs/metrics.csv to check the performance on the last epoch + precision="32", + devices=1, # set higher number if more gpus are available + max_epochs=500, default_root_dir=models_dir, logger=CSVLogger(models_dir), ) + trainer.fit(ls, datamodule=dm) \ No newline at end of file diff --git a/scripts/batch_run_xenium/train_batch_BrM.sh b/scripts/batch_run_xenium/train_batch_BrM.sh new file mode 100644 index 00000000..5685a713 --- /dev/null +++ b/scripts/batch_run_xenium/train_batch_BrM.sh @@ -0,0 +1,18 @@ +DATA_ROOT="data_tidy/pyg_datasets" # this is the ../ parent folder of where segger tiles (graphs and embeddings) are saved +MODELS_ROOT="models" # this is where the trained models are stored +PROJECT_NAME="BrM" # this is the folder tag for the datasets DATA_ROOT/PROJECT_NAME and the folder in MODELS_ROOT where the models are gonna be saved + +for full_path in "$DATA_ROOT/$PROJECT_NAME"/*; do + if [ -d "$full_path" ]; then + echo "$full_path" + dataset_tag=$(basename "$full_path") + echo "Submitting job for $dataset_tag" + bsub -o logs/train_${dataset_tag}.log \ + -gpu num=1:gmem=30G \ + -R "rusage[mem=100GB]" \ + -q gpu \ + python ../segger_dev/scripts/batch_run_xenium/train_batch.py \ + --data_dir "$DATA_ROOT/$PROJECT_NAME/$dataset_tag" \ + --models_dir "$MODELS_ROOT/$PROJECT_NAME/$dataset_tag" + fi +done \ No newline at end of file diff --git a/scripts/batch_run_xenium/train_batch_GB.sh b/scripts/batch_run_xenium/train_batch_GB.sh new file mode 100644 index 00000000..0d4119c8 --- /dev/null +++ b/scripts/batch_run_xenium/train_batch_GB.sh @@ -0,0 +1,18 @@ +DATA_ROOT="data_tidy/pyg_datasets" # this is the ../ parent folder of where segger tiles (graphs and embeddings) are saved +MODELS_ROOT="models" # this is where the trained models are stored +PROJECT_NAME="Neuronal_Panel" # this is the folder tag for the datasets DATA_ROOT/PROJECT_NAME and the folder in MODELS_ROOT where the models are gonna be saved + +for full_path in "$DATA_ROOT/$PROJECT_NAME"/*; do + if [ -d "$full_path" ]; then + echo "$full_path" + dataset_tag=$(basename "$full_path") + echo "Submitting job for $dataset_tag" + bsub -o logs/train_${dataset_tag}.log \ + -gpu 
num=1:gmem=20.7G \ + -R "rusage[mem=100GB]" \ + -q gpu-debian \ + python ../segger_dev/scripts/batch_run_xenium/train_batch.py \ + --data_dir "$DATA_ROOT/$PROJECT_NAME/$dataset_tag" \ + --models_dir "$MODELS_ROOT/$PROJECT_NAME/$dataset_tag" + fi +done \ No newline at end of file diff --git a/scripts/create_data_fast_sample.py b/scripts/create_data_fast_sample.py index b7bf6558..308e2705 100644 --- a/scripts/create_data_fast_sample.py +++ b/scripts/create_data_fast_sample.py @@ -42,13 +42,13 @@ XENIUM_DATA_DIR = Path( - "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" + "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs" ) -SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_exmax_weights") +SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/xe_bc_rep1_loss_emb2") SCRNASEQ_FILE = Path( - "data_tidy/Human_CRC/scRNAseq.h5ad" + "/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad" ) -CELLTYPE_COLUMN = "Level1" +CELLTYPE_COLUMN = "celltype_minor" scrnaseq = sc.read(SCRNASEQ_FILE) sc.pp.subsample(scrnaseq, 0.25) scrnaseq.var_names_make_unique() @@ -62,7 +62,7 @@ -# markers = find_markers(scrnaseq, cell_type_column="Level1", pos_percentile=20, neg_percentile=20, percentage=50) +# markers = find_markers(scrnaseq, cell_type_column="celltype_minor", pos_percentile=20, neg_percentile=20, percentage=50) @@ -70,7 +70,7 @@ # Initialize spatial transcriptomics sample object sample = STSampleParquet( base_dir=XENIUM_DATA_DIR, - n_workers=8, + n_workers=10, sample_type="xenium", # scale_factor=0.8, weights=gene_celltype_abundance_embedding @@ -80,10 +80,12 @@ genes = list(set(scrnaseq.var_names) & set(sample.transcripts_metadata['feature_names'])) -markers = find_markers(scrnaseq[:,genes], cell_type_column="Level1", pos_percentile=90, neg_percentile=20, percentage=20) +markers = find_markers(scrnaseq[:,genes], cell_type_column="celltype_minor", pos_percentile=90, neg_percentile=20, percentage=20) # Find mutually exclusive genes based on scRNAseq data exclusive_gene_pairs = find_mutually_exclusive_genes( - adata=scrnaseq, markers=markers, cell_type_column="Level1" + adata=scrnaseq, + markers=markers, + cell_type_column="celltype_minor" ) @@ -91,15 +93,17 @@ data_dir=SEGGER_DATA_DIR, k_bd=3, # Number of boundary points to connect dist_bd=15, # Maximum distance for boundary connections - k_tx=10, # Use calculated optimal transcript neighbors + k_tx=20, # Use calculated optimal transcript neighbors dist_tx=5, # Use calculated optimal search radius - tile_size=20_000, # Tile size for processing + k_tx_ex=20, # Use calculated optimal transcript neighbors + dist_tx_ex=20, # Use calculated optimal search radius + tile_size=10_000, # Tile size for processing # tile_height=100, - # neg_sampling_ratio=5.0, # 5:1 negative:positive samples + neg_sampling_ratio=10.0, # 10:1 negative:positive samples frac=1.0, # Use all data val_prob=0.3, # 30% validation set test_prob=0, # No test set - k_tx_ex=100, # Use calculated optimal transcript neighbors - dist_tx_ex=20, # Use calculated optimal search radius + # k_tx_ex=100, # Use calculated optimal transcript neighbors + # dist_tx_ex=20, # Use calculated optimal search radius mutually_exclusive_genes=exclusive_gene_pairs ) diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index ec140e06..b32b425e 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -8,6 +8,7 @@ import 
dask.dataframe as dd import pandas as pd from pathlib import Path +import torch os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" os.environ["CUPY_CACHE_DIR"] = "./.cupy" @@ -18,23 +19,22 @@ -seg_tag = "human_CRC_seg_nuclei" -model_version = 0 +seg_tag = "xe_bc_rep1_loss_emb2" +model_version = 6 XENIUM_DATA_DIR = Path( - "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" + "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs" ) -SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei") +SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/xe_bc_rep1_loss_emb2") models_dir = Path("./models") / seg_tag benchmarks_dir = Path( - "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_nuclei" -) -transcripts_file = ( - "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real/transcripts.parquet" + "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_bc_rep1_loss_emb2" ) +transcripts_file = XENIUM_DATA_DIR / "transcripts.parquet" + # Initialize the Lightning data module dm = SeggerDataModule( data_dir=SEGGER_DATA_DIR, @@ -44,12 +44,23 @@ dm.setup() +batch = dm.train[0] + # Load in latest checkpoint model_path = models_dir / "lightning_logs" / f"version_{model_version}" model = load_model(model_path / "checkpoints") -receptive_field = {"k_bd": 4, "dist_bd": 10, "k_tx": 5, "dist_tx": 3} + + +# batch = batch.to(f"cuda:0") +# model = model.model.to(f"cuda:0") +# out = model(batch.x_dict, batch.edge_index_dict) +# torch.save(out, 'embeddings/outs_0.pt') +# ids = batch['tx'].id +# torch.save(ids, 'embeddings/ids_0.pt') + +receptive_field = {"k_bd": 4, "dist_bd": 7.5, "k_tx": 5, "dist_tx": 3} segment( model, @@ -63,7 +74,7 @@ score_cut=0.4, # max_transcripts=1500, cell_id_col="segger_cell_id", - use_cc=False, + use_cc=True, knn_method="kd_tree", verbose=True, gpu_ids=["0"], diff --git a/scripts/train_mimmo_batch.py b/scripts/train_mimmo_batch.py index 7fe41e5a..6bbc17f0 100644 --- a/scripts/train_mimmo_batch.py +++ b/scripts/train_mimmo_batch.py @@ -49,7 +49,7 @@ hidden_channels=32, out_channels=16, heads=4, - num_mid_layers=3, + num_mid_layers=0, ) model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index 063f8f93..d7024f8b 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -17,8 +17,10 @@ -segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_exmax_weights") -models_dir = Path("./models/human_CRC_seg_exmax_weights") + + +segger_data_dir = Path("data_tidy/pyg_datasets/xe_bc_rep1_loss_emb2") +models_dir = Path("./models/xe_bc_rep1_loss_emb2") # Base directory to store Pytorch Lightning models # models_dir = Path('models') @@ -26,7 +28,7 @@ # Initialize the Lightning data module dm = SeggerDataModule( data_dir=segger_data_dir, - batch_size=2, + batch_size=1, num_workers=2, ) @@ -36,27 +38,29 @@ # num_tx_tokens = 500 # # If you use custom gene embeddings, use the following two lines instead: -# is_token_based = False -# num_tx_tokens = ( -# dm.train[0].x_dict["tx"].shape[1] -# ) # Set the number of tokens to the number of genes +is_token_based = False +num_tx_tokens = ( + dm.train[0].x_dict["tx"].shape[1] +) # Set the number of tokens to the number of genes model = Segger( - # 
is_token_based=is_token_based, - num_tx_tokens= 500, - init_emb=8, + num_tx_tokens=num_tx_tokens, + # num_tx_tokens= 600, + init_emb=16, hidden_channels=64, out_channels=16, heads=4, - num_mid_layers=2, + num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="mean") batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) # Wrap the model in LitSegger -ls = LitSegger(model=model, align_loss=True, align_lambda=1) +ls = LitSegger(model=model, align_loss=True, align_lambda=.5, cycle_length=10000) + + # # Initialize the Lightning model # ls = LitSegger( diff --git a/src/segger/data/parquet/_utils.py b/src/segger/data/parquet/_utils.py index 51cfc49a..a200b4c3 100644 --- a/src/segger/data/parquet/_utils.py +++ b/src/segger/data/parquet/_utils.py @@ -5,7 +5,7 @@ from pyarrow import parquet as pq import numpy as np import scipy as sp -from typing import Optional, List, Dict, Tuple +from typing import Optional, List, Dict, Tuple, Set, Sequence import sys from types import SimpleNamespace from pathlib import Path @@ -639,7 +639,165 @@ def find_mutually_exclusive_genes( mutually_exclusive_gene_pairs = [ tuple(sorted((gene1, gene2))) for key1, key2 in combinations(filtered_exclusive_genes.keys(), 2) + if key1 != key2 for gene1 in filtered_exclusive_genes[key1] for gene2 in filtered_exclusive_genes[key2] ] - return set(mutually_exclusive_gene_pairs) \ No newline at end of file + return set(mutually_exclusive_gene_pairs) + + + +# def find_mutually_exclusive_genes( +# adata: ad.AnnData, threshold: float = 0.0001 +# ) -> List[Tuple[str, str]]: +# """Identify pairs of genes with coexpression below a specified threshold. + +# Args: +# - adata: AnnData +# Annotated data object containing gene expression data. +# - threshold: float +# Coexpression threshold below which gene pairs are considered. + +# Returns: +# - low_coexpression_pairs: list +# List of gene pairs with low coexpression. +# """ +# gene_expression = adata.to_df() +# genes = gene_expression.columns +# low_coexpression_pairs = [] + +# for gene1, gene2 in combinations(genes, 2): +# expr1 = gene_expression[gene1] > 0 +# expr2 = gene_expression[gene2] > 0 +# coexpression = (expr1 * expr2).mean() + +# if coexpression < threshold * (expr1.mean() + expr2.mean()): +# low_coexpression_pairs.append(tuple(sorted((gene1, gene2)))) + +# return set(low_coexpression_pairs) + + + +# def find_mutually_exclusive_genes( +# adata: ad.AnnData, +# *, +# threshold: float = 1e-4, +# expr_cutoff: float = 0.0, +# block_size: int = 2048, +# ) -> Set[Tuple[str, str]]: +# """ +# Identify gene pairs (i, j) with coexpression below a specified threshold: +# mean( expr_i & expr_j ) < threshold * ( mean(expr_i) + mean(expr_j) ) +# computed via matrix operations on (cells x genes) data. + +# Parameters +# ---------- +# adata : AnnData +# Cells x Genes matrix (adata.X or adata.layers[layer]). +# threshold : float, default 1e-4 +# Coexpression threshold weight for RHS. +# layer : str or None, default None +# Use adata.layers[layer] if provided; otherwise adata.X. +# genes : sequence of str or None +# Optional subset/order of genes to evaluate. +# expr_cutoff : float, default 0.0 +# A cell expresses a gene if value > expr_cutoff. +# block_size : int, default 2048 +# Number of genes per block for the blockwise sparse multiplication +# to control memory (recommended: 1k–10k depending on RAM). 
+ +# Returns +# ------- +# low_coexp_pairs : set of (gene_i, gene_j) +# Gene name pairs with i < j (lexicographic order preserved by indices). +# """ +# # Select matrix and (optionally) subset genes +# layer = None +# genes = None +# X = adata.layers[layer] if layer is not None else adata.X +# if genes is not None: +# adata = adata[:, list(genes)] +# X = adata.layers[layer] if layer is not None else adata.X + +# var_names = adata.var_names +# n_cells, n_genes = adata.n_obs, adata.n_vars + +# # Binarize to a boolean CSR: expressed if > expr_cutoff +# if sp.issparse(X): +# Xb = X.tocsr().astype(np.float32) +# Xb.data = (Xb.data > expr_cutoff).astype(np.uint8) +# Xb.eliminate_zeros() +# else: +# Xb = sp.csr_matrix((X > expr_cutoff).astype(np.uint8)) + +# # Per-gene expression fraction p(gene) = mean over cells +# colsum = np.asarray(Xb.sum(axis=0)).ravel().astype(np.int64) # counts of expressing cells +# p = colsum / float(n_cells) # shape (G,) + +# # We'll scan genes in blocks: for each block B, compute (Xb.T_B @ Xb) -> (B x G) intersection counts +# result_pairs: Set[Tuple[str, str]] = set() +# tN = threshold * n_cells # scale to compare counts on LHS + +# for start in range(0, n_genes, block_size): +# stop = min(start + block_size, n_genes) +# # (cells x B) +# Xb_block = Xb[:, start:stop] # CSR +# # Intersection counts for the block against all genes: (B x G) +# inter_BG = (Xb_block.T @ Xb).tocoo() + +# # Build RHS for the whole block as a dense (B x G) using outer sums p_block[:,None] + p[None,:] +# p_block = p[start:stop] # (B,) +# rhs_BG = tN * (p_block[:, None] + p[None, :]) # dense small block + +# # We need to test ALL pairs i in block, j in [0..G), not just where inter_BG has nonzeros. +# # Strategy: +# # 1) Start by assuming inter_ij = 0 for all pairs in the block (since absent in sparse). +# # For those, condition is: 0 < rhs_BG[i,j] -> typically true unless rhs==0. +# # 2) Then overwrite where we DO have nonzero intersections with the actual counts and re-test. +# # +# # Step 1: zero-intersection candidates (exclude diagonal and ensure i < j). +# zero_mask = np.ones((stop - start, n_genes), dtype=bool) +# zero_mask[inter_BG.row, inter_BG.col] = False +# # Keep only candidates where rhs_BG[i,j] > 0 (else the inequality cannot hold). 
+# cand_mask = zero_mask & (rhs_BG > 0) + +# # Enforce i < j (global indices) +# # Convert to global indices and filter upper triangle +# if np.any(cand_mask): +# rows, cols = np.where(cand_mask) +# gi = rows + start +# gj = cols +# keep = gi < gj +# gi, gj = gi[keep], gj[keep] +# # Add pairs +# for ii, jj in zip(gi, gj): +# result_pairs.add((var_names[ii], var_names[jj])) + +# # Step 2: handle nonzero intersections from inter_BG +# if inter_BG.nnz: +# gi = inter_BG.row + start +# gj = np.asarray(inter_BG.col) +# # Enforce i < j and exclude diagonal +# keep = gi < gj +# gi, gj = gi[keep], gj[keep] +# inter_vals = inter_BG.data[keep].astype(np.float64) + +# # Compare: inter_ij < tN * (p_i + p_j) +# rhs_vals = tN * (p[gi] + p[gj]) +# mask = inter_vals < rhs_vals + +# for ii, jj, ok in zip(gi, gj, mask): +# if ok: +# result_pairs.add((var_names[ii], var_names[jj])) + +# # help GC +# del Xb_block, inter_BG, rhs_BG, zero_mask, cand_mask + +# return result_pairs \ No newline at end of file diff --git a/src/segger/data/parquet/sample.py b/src/segger/data/parquet/sample.py index 0e5e4fd2..f07d8470 100644 --- a/src/segger/data/parquet/sample.py +++ b/src/segger/data/parquet/sample.py @@ -484,6 +484,8 @@ def save_debug( dist_bd: float = 15.0, k_tx: int = 3, dist_tx: float = 5.0, + k_tx_ex: int = 100, + dist_tx_ex: float = 20, tile_width: Optional[float] = None, tile_height: Optional[float] = None, neg_sampling_ratio: float = 5.0, @@ -552,6 +554,8 @@ def save_debug( dist_bd=dist_bd, k_tx=k_tx, dist_tx=dist_tx, + k_tx_ex=k_tx_ex, + dist_tx_ex=dist_tx_ex, neg_sampling_ratio=neg_sampling_ratio, ) @@ -1181,6 +1185,9 @@ def get_boundary_props( props = torch.as_tensor(props.values).float() return props + + def canonical_edges(edge_index): + return torch.sort(edge_index, dim=0)[0] def to_pyg_dataset( self, @@ -1241,23 +1248,46 @@ def to_pyg_dataset( if mutually_exclusive_genes is not None: - nbrs_edge_idx = self.get_kdtree_edge_index( + # Get potential repulsive edges (k-nearest neighbors within distance) + # --- Step 1: Get repulsive edges (mutually exclusive genes) --- + repels_edge_idx = self.get_kdtree_edge_index( self.transcripts[self.settings.transcripts.xyz], self.transcripts[self.settings.transcripts.xyz], k=k_tx_ex, max_distance=dist_tx_ex, ) gene_ids = self.transcripts[self.settings.transcripts.label].tolist() - src_gene = [gene_ids[idx] for idx in nbrs_edge_idx[0].tolist()] - dst_gene = [gene_ids[idx] for idx in nbrs_edge_idx[1].tolist()] - + + # Filter repels_edge_idx to only keep mutually exclusive gene pairs + src_genes = [gene_ids[i] for i in repels_edge_idx[0].tolist()] + dst_genes = [gene_ids[i] for i in repels_edge_idx[1].tolist()] mask = [ - tuple(sorted((a, b))) in mutually_exclusive_genes - for a, b in zip(src_gene, dst_gene) - ] - mask = torch.tensor(mask) - - pyg_data["tx", "excludes", "tx"].edge_index = nbrs_edge_idx[:, mask] + tuple(sorted((a, b))) in mutually_exclusive_genes if a != b else False + for a, b in zip(src_genes, dst_genes) + ] + repels_edge_idx = repels_edge_idx[:, torch.tensor(mask)] + + # --- Step 2: Get attractive edges (same gene, at least one node in repels) --- + # Nodes involved in repels (for filtering nbrs_edge_idx) + repels_nodes = torch.cat([repels_edge_idx[0], repels_edge_idx[1]]).unique() + + # Filter nbrs_edge_idx: keep edges where (1) same gene AND (2) at least one node in repels + attractive_mask = torch.zeros(nbrs_edge_idx.shape[1], dtype=torch.bool) + for i, (src, dst) in enumerate(nbrs_edge_idx.t().tolist()): + if (src != dst) and (gene_ids[src] == 
gene_ids[dst]) and (src in repels_nodes or dst in repels_nodes): + attractive_mask[i] = True + attractive_edge_idx = nbrs_edge_idx[:, attractive_mask] + + # --- Step 3: Combine repels (label=0) and attractive (label=1) edges --- + edge_label_index = torch.cat([repels_edge_idx, attractive_edge_idx], dim=1) + edge_label = torch.cat([ + torch.zeros(repels_edge_idx.shape[1], dtype=torch.long), # 0 for repels + torch.ones(attractive_edge_idx.shape[1], dtype=torch.long) # 1 for attracts + ]) + + # --- Step 4: Store in PyG data object --- + pyg_data["tx", "attracts", "tx"].edge_label_index = edge_label_index + pyg_data["tx", "attracts", "tx"].edge_label = edge_label # Set up Boundary nodes @@ -1336,39 +1366,24 @@ def to_pyg_dataset( if blng_edge_idx.numel() == 0: return pyg_data - # If there are tx-bd edges, add negative edges for training - pos_edges = blng_edge_idx # shape (2, num_pos) - num_pos = pos_edges.shape[1] - - # Negative edges (tx-neighbors-bd) - EXCLUDE positives - neg_candidates = nbrs_edge_idx # shape (2, num_candidates) - - # --- Fast Negative Filtering (PyTorch-only) --- - # Reshape edges for broadcasting: (2, num_pos) vs (2, num_candidates, 1) - pos_expanded = pos_edges.unsqueeze(2) # shape (2, num_pos, 1) - neg_expanded = neg_candidates.unsqueeze(1) # shape (2, 1, num_candidates) - - # Compare all edges in one go (broadcasting) - matches = (pos_expanded == neg_expanded).all(dim=0) # shape (num_pos, num_candidates) - is_negative = ~matches.any(dim=0) # shape (num_candidates,) - - # Filter negatives - neg_edges = neg_candidates[:, is_negative] # shape (2, num_filtered_neg) - num_neg = neg_edges.shape[1] - - # --- Combine and label --- - edge_label_index = torch.cat([neg_edges, pos_edges], dim=1) - edge_label = torch.cat([ - torch.zeros(num_neg, dtype=torch.float), - torch.ones(num_pos, dtype=torch.float) - ]) + # If there are tx-bd edges, add negative edges for training + transform = RandomLinkSplit( + num_val=0, + num_test=0, + is_undirected=True, + edge_types=[edge_type], + neg_sampling_ratio=neg_sampling_ratio, + ) + pyg_data, _, _ = transform(pyg_data) - mask = edge_label_index[0].unsqueeze(1) == blng_edge_idx[0].unsqueeze(0) + # Refilter negative edges to include only transcripts in the + # original positive edges (still need a memory-efficient solution) + edges = pyg_data[edge_type] + mask = edges.edge_label_index[0].unsqueeze(1) == edges.edge_index[0].unsqueeze( + 0 + ) mask = torch.nonzero(torch.any(mask, 1)).squeeze() - edge_label_index = edge_label_index[:, mask] - edge_label = edge_label[mask] - - pyg_data[edge_type].edge_label_index = edge_label_index - pyg_data[edge_type].edge_label = edge_label + edges.edge_label_index = edges.edge_label_index[:, mask] + edges.edge_label = edges.edge_label[mask] return pyg_data diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py index abf6f159..5930e246 100644 --- a/src/segger/models/segger_model.py +++ b/src/segger/models/segger_model.py @@ -1,12 +1,18 @@ import torch from torch_geometric.nn import GATv2Conv, Linear -from torch.nn import Embedding -from torch import Tensor -from typing import Union +from torch import Tensor +from typing import Union, Dict, Tuple, Optional +from torch.nn import ( + Embedding, + ModuleList, + Module, + functional as F +) # from torch_sparse import SparseTensor + class Segger(torch.nn.Module): def __init__( self, @@ -93,6 +99,8 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: # Last layer x = self.conv_last(x, edge_index) # + self.lin_last(x) + # x = 
F.normalize(x) + return x def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor: @@ -106,4 +114,4 @@ def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor: Returns: Tensor: Predicted edge values. """ - return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) + return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) \ No newline at end of file diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index 33ec5521..6613f91c 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -181,6 +181,7 @@ def sort_order(c): # Load model from checkpoint lit_segger = LitSegger.load_from_checkpoint( checkpoint_path=checkpoint_path, + strict=False ) return lit_segger diff --git a/src/segger/training/train.py b/src/segger/training/train.py index 0ff9d549..ce90bedd 100644 --- a/src/segger/training/train.py +++ b/src/segger/training/train.py @@ -5,6 +5,7 @@ from torchmetrics import F1Score from lightning import LightningModule from segger.models.segger_model import * +import torch.nn.functional as F class LitSegger(LightningModule): """ @@ -17,7 +18,8 @@ def __init__( model: Segger, learning_rate: float = 1e-3, align_loss: bool = False, - align_lambda: float = 0 + align_lambda: float = 0, + cycle_length: int = 1000, # Steps per cycle (e.g., 1000 steps) ): """ Initialize the Segger training module. @@ -39,9 +41,18 @@ def __init__( self.align_loss = align_loss self.align_lambda = align_lambda self.criterion = torch.nn.BCEWithLogitsLoss() + self.criterion_align = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(6.0)) + self.cycle_length = cycle_length self.validation_step_outputs = [] + def get_cosine_weight(self, step: int) -> float: + """Compute cyclic weight for align_loss using cosine scheduling.""" + # Cosine varies between 0 and align_lambda + weight = (1 + torch.cos(torch.tensor(2 * torch.pi * step / self.cycle_length))) / 2 + return weight + + def forward(self, batch) -> torch.Tensor: """ Forward pass for the batch of data. 
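# A minimal sketch of the cyclic schedule implemented by get_cosine_weight above,
# assuming the constructor default cycle_length=1000: the weight starts at 1,
# falls to 0 at mid-cycle, and climbs back to 1 at the end of each cycle.
import math
cycle_length = 1000
for step in (0, 250, 500, 750, 1000):
    weight = (1 + math.cos(2 * math.pi * step / cycle_length)) / 2
    print(step, round(weight, 3))  # 1.0, 0.5, 0.0, 0.5, 1.0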
@@ -85,14 +96,24 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: loss = self.criterion(out_values, edge_label) # Log the training loss self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs) - if self.align_loss: - if self.align_loss: - edge_index = batch["tx", "excludes", "tx"].edge_index - out_values = (z["tx"][edge_index[0]] * z["tx"][edge_index[1]]).sum(-1) - targets = torch.zeros_like(out_values) - align_loss = self.criterion(out_values, targets) - self.log("align_loss", align_loss, prog_bar=True, batch_size=batch.num_graphs) - loss = loss + align_loss * self.align_lambda + if self.align_loss: + edge_index = batch["tx", "attracts", "tx"].edge_label_index + edge_label = batch["tx", "attracts", "tx"].edge_label.float() + pos_weight = (edge_label == 0).sum() / (edge_label == 1).sum() + self.log("pos_weight", pos_weight, prog_bar=True, batch_size=batch.num_graphs) + z_tx = z["tx"] + out_values = (z_tx[edge_index[0]] * z_tx[edge_index[1]]).sum(-1) + align_loss = self.criterion_align(out_values, edge_label) + self.log("align_max", torch.max(out_values), prog_bar=True, batch_size=batch.num_graphs) + self.log("align_min", torch.min(out_values), prog_bar=True, batch_size=batch.num_graphs) + self.log("align_loss", align_loss, prog_bar=True, batch_size=batch.num_graphs) + current_step = self.global_step + align_weight = self.get_cosine_weight(current_step) + # self.log("align_weight", align_weight, prog_bar=True, batch_size=batch.num_graphs) + loss = self.align_lambda * align_loss + (1-self.align_lambda) * loss + # loss = self.align_lambda * align_loss + loss + # loss = align_loss + #TOOD: cosine scheduling -- add self-loops return loss def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: From 4711f83174f978b14a2ff4e7ecd0ad6b10a9389d Mon Sep 17 00:00:00 2001 From: Elyas Heidari <55977725+EliHei2@users.noreply.github.com> Date: Mon, 29 Sep 2025 10:22:38 +0200 Subject: [PATCH 3/3] fixed the bug in sample --- src/segger/data/parquet/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/segger/data/parquet/sample.py b/src/segger/data/parquet/sample.py index f07d8470..212821b7 100644 --- a/src/segger/data/parquet/sample.py +++ b/src/segger/data/parquet/sample.py @@ -1306,7 +1306,7 @@ def to_pyg_dataset( ) # Ensure self.boundaries is a GeoDataFrame with correct geometry - self.boundaries = gpd.GeoDataFrame(self.boundaries.copy(), geometry=polygons) + self.boundaries = gpd.GeoDataFrame(index = polygons.index, geometry=polygons) centroids = polygons.centroid.get_coordinates() pyg_data["bd"].id = polygons.index.to_numpy() pyg_data["bd"].pos = torch.tensor(centroids.values, dtype=torch.float32)
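A minimal, self-contained sketch of the mutually exclusive gene-pair ("align") loss that this patch series wires through to_pyg_dataset and LitSegger.training_step, assuming precomputed transcript embeddings z_tx and the tx-tx edge labels built above (1 for same-gene "attracts" pairs, 0 for mutually exclusive "repels" pairs); the fixed pos_weight of 6.0 mirrors criterion_align, and the function name is illustrative only.

import torch
import torch.nn.functional as F

def align_loss_sketch(z_tx: torch.Tensor,
                      edge_label_index: torch.Tensor,
                      edge_label: torch.Tensor,
                      pos_weight: float = 6.0) -> torch.Tensor:
    # Dot-product score for each labeled transcript pair, matching Segger.decode.
    logits = (z_tx[edge_label_index[0]] * z_tx[edge_label_index[1]]).sum(dim=-1)
    # BCE-with-logits pulls same-gene pairs together (label 1) and pushes
    # mutually exclusive gene pairs apart (label 0).
    return F.binary_cross_entropy_with_logits(
        logits, edge_label.float(), pos_weight=torch.tensor(pos_weight)
    )

# Toy usage: 10 transcripts with 16-dim embeddings and 4 labeled pairs.
z_tx = torch.randn(10, 16)
edge_label_index = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
edge_label = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(align_loss_sketch(z_tx, edge_label_index, edge_label))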