Skip to content

Commit a6285d8

Browse files
committed
Linting and formatting
1 parent 03fe219 commit a6285d8

30 files changed

+384
-366
lines changed

pf2rnaseq/ParameterOpt.py

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,45 @@
22
Hyperparameter sweep for Pf2 using Weights & Biases
33
Optimizing rank and regularization parameter
44
"""
5-
import os
5+
66
import numpy as np
7-
import pandas as pd
8-
import anndata
97
import wandb
10-
from tensorly.cp_tensor import CPTensor
11-
from tlviz.factor_tools import factor_match_score as fms
12-
import matplotlib.pyplot as plt
13-
import seaborn as sns
14-
158
from factorization import pf2
169
from imports import import_cytokine
17-
10+
from tensorly.cp_tensor import CPTensor
11+
from tlviz.factor_tools import factor_match_score as fms
1812

1913
ranks = np.arange(1, 31)
2014
# Define the sweep configuration
2115
sweep_config = {
22-
'method': 'grid', # grid search for thorough exploration
23-
'metric': {
24-
'name': 'fms', # optimize for factor match score
25-
'goal': 'maximize' # we want to maximize factor stability
16+
"method": "grid", # grid search for thorough exploration
17+
"metric": {
18+
"name": "fms", # optimize for factor match score
19+
"goal": "maximize", # we want to maximize factor stability
2620
},
27-
'parameters': {
28-
'rank': {
29-
'values': ranks # Different component numbers to test
21+
"parameters": {
22+
"rank": {
23+
"values": ranks # Different component numbers to test
3024
},
31-
'regParam': {
32-
'values': [0.0, 1e-6, 1e-5, 5e-5, 1e-4] # Different L1 regularization strengths
33-
}
34-
}
25+
"regParam": {
26+
"values": [
27+
0.0,
28+
1e-6,
29+
1e-5,
30+
5e-5,
31+
1e-4,
32+
] # Different L1 regularization strengths
33+
},
34+
},
3535
}
3636

37+
3738
def resample(data):
3839
"""Bootstrapping dataset"""
3940
indices = np.random.randint(0, data.shape[0], size=(data.shape[0],))
4041
return data[indices].copy()
4142

43+
4244
def calculateFMS(A, B):
4345
"""Calculates FMS between 2 factorizations"""
4446
A_factors = [A.uns["Pf2_A"], A.uns["Pf2_B"], A.varm["Pf2_C"]]
@@ -49,78 +51,82 @@ def calculateFMS(A, B):
4951

5052
return fms(A_CP, B_CP, consider_weights=False, skip_mode=1)
5153

54+
5255
def calculate_sparsity(matrix, threshold=1e-6):
5356
"""Calculate sparsity (proportion of near-zero elements)"""
5457
total_elements = matrix.size
5558
near_zero_elements = np.sum(np.abs(matrix) < threshold)
5659
return near_zero_elements / total_elements
5760

61+
5862
def train():
5963
"""Main training function for wandb sweep"""
6064
# Initialize a new wandb run
6165
with wandb.init() as run:
6266
# Get parameters from wandb
6367
config = wandb.config
64-
68+
6569
# Load data (do this once per run to save time)
6670
X = import_cytokine()
6771
print(f"Running with rank={config.rank}, regParam={config.regParam}")
68-
72+
6973
# Set number of bootstrap samples
7074
n_bootstrap = 3
71-
75+
7276
# Run base factorization with current parameters
73-
base_model, r2x = pf2(X,
74-
rank=config.rank,
75-
random_state=42,
76-
doEmbedding=False,
77-
regParam=config.regParam,
78-
r2x=True)
79-
80-
77+
base_model, r2x = pf2(
78+
X,
79+
rank=config.rank,
80+
random_state=42,
81+
doEmbedding=False,
82+
regParam=config.regParam,
83+
r2x=True,
84+
)
85+
8186
sparsity_C = calculate_sparsity(base_model.varm["Pf2_C"])
82-
83-
87+
8488
# Log R2X and sparsity metrics
85-
wandb.log({
86-
"r2x": r2x,
87-
88-
"sparsity_C": sparsity_C
89-
90-
})
91-
89+
wandb.log({"r2x": r2x, "sparsity_C": sparsity_C})
90+
9291
# Calculate FMS across bootstrap samples
9392
fms_scores = []
9493
for i in range(n_bootstrap):
9594
# Create bootstrap sample
9695
bootstrap_data = resample(X)
97-
96+
9897
# Run factorization on bootstrap sample
99-
bootstrap_model = pf2(bootstrap_data,
100-
rank=config.rank,
101-
random_state=i,
102-
doEmbedding=False,
103-
regParam=config.regParam)
104-
98+
bootstrap_model = pf2(
99+
bootstrap_data,
100+
rank=config.rank,
101+
random_state=i,
102+
doEmbedding=False,
103+
regParam=config.regParam,
104+
)
105+
105106
# Calculate FMS between base model and bootstrap model
106107
fms_score = calculateFMS(base_model, bootstrap_model)
107108
fms_scores.append(fms_score)
108-
109+
109110
# Log individual bootstrap FMS
110111
wandb.log({f"fms_bootstrap_{i}": fms_score})
111-
112+
112113
# Calculate and log average FMS
113114
avg_fms = np.mean(fms_scores)
114115
wandb.log({"fms": avg_fms})
115-
116-
print(f"Completed run: rank={config.rank}, regParam={config.regParam}, R2X={r2x:.4f}, FMS={avg_fms:.4f}")
116+
117+
print(
118+
f"Completed run: rank={config.rank}, regParam={config.regParam}, R2X={r2x:.4f}, FMS={avg_fms:.4f}"
119+
)
120+
117121

118122
if __name__ == "__main__":
119123
# Initialize wandb
120124
wandb.login()
121-
125+
122126
# Create the sweep
123127
sweep_id = wandb.sweep(sweep_config, project="Pf2_parameter_optimization2")
124-
128+
125129
# Run the sweep
126-
wandb.agent(sweep_id, function=train, count=None) # Set count if you want to limit runs
130+
wandb.agent(
131+
sweep_id, function=train, count=None
132+
) # Set count if you want to limit runs

pf2rnaseq/factorization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
import anndata
2+
import cupy
3+
import numpy as np
4+
import scipy.sparse as sps
15
from pacmap import PaCMAP
2-
from sklearn.linear_model import LinearRegression
3-
from scipy.stats import gmean
46
from parafac2.parafac2 import parafac2_nd, store_pf2
7+
from scipy.stats import gmean
58
from sklearn.decomposition import PCA
6-
import anndata
7-
import scipy.sparse as sps
8-
import numpy as np
9+
from sklearn.linear_model import LinearRegression
910
from tqdm import tqdm
10-
import cupy
1111

1212

1313
def correct_conditions(X: anndata.AnnData):

pf2rnaseq/figures/common.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
This file contains functions that are used in multiple figures.
33
"""
44

5-
from string import ascii_letters
65
import sys
76
import time
8-
import seaborn as sns
7+
from string import ascii_letters
8+
99
import matplotlib
10+
import seaborn as sns
11+
from matplotlib import gridspec
12+
from matplotlib import pyplot as plt
1013
from matplotlib.figure import Figure
11-
from matplotlib import gridspec, pyplot as plt
12-
1314

1415
matplotlib.use("AGG")
1516

pf2rnaseq/figures/commonFuncs/plotFactors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
from typing import Optional
2-
from anndata import AnnData
1+
32
import numpy as np
43
import pandas as pd
4+
import scipy.cluster.hierarchy as sch
55
import seaborn as sns
6+
from anndata import AnnData
67
from matplotlib import pyplot as plt
7-
import scipy.cluster.hierarchy as sch
8-
from matplotlib.patches import Patch
98
from matplotlib.axes import Axes
9+
from matplotlib.patches import Patch
1010

1111
cmap = sns.diverging_palette(240, 10, as_cmap=True)
1212

1313

1414
def plot_condition_factors(
1515
data: AnnData,
1616
ax: Axes,
17-
cond_group_labels: Optional[pd.Series] = None,
17+
cond_group_labels: pd.Series | None = None,
1818
groupConditions=False,
1919
cond="Condition",
2020
):
@@ -82,8 +82,8 @@ def plot_condition_factors(
8282
def plot_condition_factors_groups(
8383
data: AnnData,
8484
ax: Axes,
85-
cond_group_labels: Optional[pd.Series] = None,
86-
subgroup_labels: Optional[pd.Series] = None,
85+
cond_group_labels: pd.Series | None = None,
86+
subgroup_labels: pd.Series | None = None,
8787
groupConditions=False,
8888
cond="Condition",
8989
main_group_title="Treatment",

pf2rnaseq/figures/figureCITEseq1.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
and ratio of condition components based on days
44
"""
55

6+
import anndata
7+
import numpy as np
68
from anndata import read_h5ad
79
from matplotlib.axes import Axes
8-
import anndata
9-
from .common import subplotLabel, getSetup
10+
11+
from .common import getSetup, subplotLabel
1012
from .commonFuncs.plotFactors import (
1113
plot_condition_factors,
1214
plot_eigenstate_factors,
13-
plot_gene_factors,
1415
plot_factor_weight,
16+
plot_gene_factors,
1517
)
1618
from .commonFuncs.plotPaCMAP import plot_labels_pacmap
17-
import numpy as np
1819

1920

2021
def makeFigure():

pf2rnaseq/figures/figureCITEseq2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44

55
from anndata import read_h5ad
6+
67
from .common import (
7-
subplotLabel,
88
getSetup,
9+
subplotLabel,
910
)
1011
from .commonFuncs.plotPaCMAP import plot_wp_pacmap, plot_wp_per_celltype
1112

pf2rnaseq/figures/figureCITEseq3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44

55
from anndata import read_h5ad
6+
67
from .common import (
7-
subplotLabel,
88
getSetup,
9+
subplotLabel,
910
)
1011
from .commonFuncs.plotPaCMAP import plot_gene_pacmap
1112

pf2rnaseq/figures/figureCITEseq4.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44

55
from anndata import read_h5ad
6+
67
from .common import (
7-
subplotLabel,
88
getSetup,
9+
subplotLabel,
910
)
1011
from .commonFuncs.plotFactors import plot_gene_factors_partial
1112

pf2rnaseq/figures/figureCITEseq5.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44

55
from anndata import read_h5ad
6+
67
from .common import (
7-
subplotLabel,
88
getSetup,
9+
subplotLabel,
910
)
1011
from .commonFuncs.plotFactors import bot_top_genes
1112
from .commonFuncs.plotGeneral import plot_avegene_per_celltype
@@ -19,16 +20,16 @@ def makeFigure():
1920
# Add subplot labels
2021
subplotLabel(ax)
2122

22-
#X = read_h5ad("/opt/pf2/CITEseq_fitted_annotated.h5ad", backed="r")
23-
X = read_h5ad("/home/nicoleb/Pf2-scRNAseq-1/pf2rnaseq/Cytokine_Pf2_annotated_NB_031725.h5ad")
24-
25-
23+
# X = read_h5ad("/opt/pf2/CITEseq_fitted_annotated.h5ad", backed="r")
24+
X = read_h5ad(
25+
"/home/nicoleb/Pf2-scRNAseq-1/pf2rnaseq/Cytokine_Pf2_annotated_NB_031725.h5ad"
26+
)
2627

27-
comps = [1,12,30]
28+
comps = [1, 12, 30]
2829
genes = bot_top_genes(X, cmp=comps[1], geneAmount=10)
2930

3031
for i, gene in enumerate(genes):
3132
plot_avegene_per_celltype(X, gene, ax[i], cellType="CellType2")
32-
#ax[1].get_legend().remove()
33+
# ax[1].get_legend().remove()
3334

3435
return f

pf2rnaseq/figures/figureCITEseq6.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
CITEseq: Cell type percentage per Leiden cluster per condition
33
"""
44

5-
from anndata import read_h5ad
6-
from .common import subplotLabel, getSetup
75
import seaborn as sns
6+
from anndata import read_h5ad
7+
8+
from .common import getSetup, subplotLabel
89
from .commonFuncs.plotGeneral import cell_count_perc_df
910

1011

0 commit comments

Comments
 (0)