11from ....tools .conversion import r_function
22from ....tools .decorators import method
33
4+ import functools
45import logging
56
6- _alra = r_function ("alra.R" )
7+ _r_alra = r_function ("alra.R" )
78
89log = logging .getLogger ("openproblems" )
910
1011
11- @method (
12- method_name = "ALRA (sqrt norm, reversed normalization)" ,
12+ method_name = ("ALRA (sqrt norm, reversed normalization)" ,)
13+ _alra_method = functools .partial (
14+ method ,
1315 paper_name = "Zero-preserving imputation of scRNA-seq data using "
1416 "low-rank approximation" ,
1517 paper_reference = "linderman2018zero" ,
1618 paper_year = 2018 ,
1719 code_url = "https://github.com/KlugerLab/ALRA" ,
1820 image = "openproblems-r-extras" ,
1921)
20- def alra_sqrt (adata , test = False ):
22+
23+
24+ def _alra (adata , normtype = "log" , reverse_norm_order = False , test = False ):
2125 import numpy as np
2226 import rpy2 .rinterface_lib .embedded
2327 import scprep
2428
25- # libsize and sqrt norm
26- adata .obsm ["train_norm" ] = scprep .utils .matrix_transform (
27- adata .obsm ["train" ], np .sqrt
28- )
29- adata .obsm ["train_norm" ], libsize = scprep .normalize .library_size_normalize (
30- adata .obsm ["train_norm" ], rescale = 1 , return_library_size = True
31- )
32- adata .obsm ["train_norm" ] = adata .obsm ["train_norm" ].tocsr ()
29+ if normtype == "sqrt" :
30+ norm_fn = np .sqrt
31+ denorm_fn = np .square
32+ elif normtype == "log" :
33+ norm_fn = np .log1p
34+ denorm_fn = np .expm1
35+ else :
36+ raise NotImplementedError
37+
38+ X = adata .obsm ["train" ].copy ()
39+ if reverse_norm_order :
40+ # inexplicably, this sometimes performs better
41+ X = scprep .utils .matrix_transform (X , norm_fn )
42+ X , libsize = scprep .normalize .library_size_normalize (
43+ X , rescale = 1 , return_library_size = True
44+ )
45+ else :
46+ X , libsize = scprep .normalize .library_size_normalize (
47+ X , rescale = 1 , return_library_size = True
48+ )
49+ X = scprep .utils .matrix_transform (X , norm_fn )
50+
51+ adata .obsm ["train_norm" ] = X .tocsr ()
3352 # run alra
34- # _alra takes sparse array, returns dense array
53+ # _r_alra takes sparse array, returns dense array
3554 Y = None
3655 attempts = 0
3756 while Y is None :
3857 try :
39- Y = _alra (adata )
58+ Y = _r_alra (adata )
4059 except rpy2 .rinterface_lib .embedded .RRuntimeError : # pragma: no cover
4160 if attempts < 10 :
4261 attempts += 1
@@ -46,57 +65,37 @@ def alra_sqrt(adata, test=False):
4665
4766 # transform back into original space
4867 # functions are reversed!
49- Y = scprep .utils .matrix_transform (Y , np . square )
68+ Y = scprep .utils .matrix_transform (Y , denorm_fn )
5069 Y = scprep .utils .matrix_vector_elementwise_multiply (Y , libsize , axis = 0 )
5170 adata .obsm ["denoised" ] = Y
5271
5372 adata .uns ["method_code_version" ] = "1.0.0"
5473 return adata
5574
5675
57- @method (
58- method_name = "ALRA (log norm)" ,
59- paper_name = "Zero-preserving imputation of scRNA-seq data using "
60- "low-rank approximation" ,
61- paper_reference = "linderman2018zero" ,
62- paper_year = 2018 ,
63- code_url = "https://github.com/KlugerLab/ALRA" ,
64- image = "openproblems-r-extras" ,
76+ @_alra_method (
77+ method_name = "ALRA (sqrt norm, reversed normalization)" ,
6578)
66- def alra_log (adata , test = False ):
67- import numpy as np
68- import rpy2 .rinterface_lib .embedded
69- import scprep
79+ def alra_sqrt_reversenorm (adata , test = False ):
80+ return _alra (adata , normtype = "sqrt" , reverse_norm_order = True , test = False )
7081
71- # libsize and log norm
72- # lib norm
73- adata .obsm ["train_norm" ], libsize = scprep .normalize .library_size_normalize (
74- adata .obsm ["train" ], rescale = 1 , return_library_size = True
75- )
76- # log
77- adata .obsm ["train_norm" ] = scprep .utils .matrix_transform (
78- adata .obsm ["train_norm" ], np .log1p
79- )
80- # to csr
81- adata .obsm ["train_norm" ] = adata .obsm ["train_norm" ].tocsr ()
82- # run alra
83- # _alra takes sparse array, returns dense array
84- Y = None
85- attempts = 0
86- while Y is None :
87- try :
88- Y = _alra (adata )
89- except rpy2 .rinterface_lib .embedded .RRuntimeError : # pragma: no cover
90- if attempts < 10 :
91- attempts += 1
92- log .warning (f"alra.R failed (attempt { attempts } )" )
93- else :
94- raise
9582
96- # transform back into original space
97- Y = scprep .utils .matrix_transform (Y , np .expm1 )
98- Y = scprep .utils .matrix_vector_elementwise_multiply (Y , libsize , axis = 0 )
99- adata .obsm ["denoised" ] = Y
83+ @_alra_method (
84+ method_name = "ALRA (log norm, reversed normalization)" ,
85+ )
86+ def alra_log_reversenorm (adata , test = False ):
87+ return _alra (adata , normtype = "log" , reverse_norm_order = True , test = False )
10088
101- adata .uns ["method_code_version" ] = "1.0.0"
102- return adata
89+
90+ @_alra_method (
91+ method_name = "ALRA (sqrt norm)" ,
92+ )
93+ def alra_sqrt (adata , test = False ):
94+ return _alra (adata , normtype = "sqrt" , reverse_norm_order = False , test = False )
95+
96+
97+ @_alra_method (
98+ method_name = "ALRA (log norm)" ,
99+ )
100+ def alra_log (adata , test = False ):
101+ return _alra (adata , normtype = "log" , reverse_norm_order = False , test = False )
0 commit comments