Skip to content

Commit c05f7c1

Browse files
authored
Merge pull request #346 from jkmckenna/0.3.2
format/lint
2 parents 37dca67 + 97db9e7 commit c05f7c1

File tree

5 files changed

+97
-96
lines changed

5 files changed

+97
-96
lines changed

src/smftools/cli/chimeric_adata.py

Lines changed: 41 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,7 @@ def chimeric_adata_core(
403403
max_segments_per_read=getattr(
404404
cfg, "rolling_nn_zero_pairs_max_segments_per_read", None
405405
),
406-
max_segment_overlap=getattr(
407-
cfg, "rolling_nn_zero_pairs_max_overlap", None
408-
),
406+
max_segment_overlap=getattr(cfg, "rolling_nn_zero_pairs_max_overlap", None),
409407
)
410408
adata.uns.setdefault(
411409
f"{cfg.rolling_nn_obsm_key}_zero_pairs_map", {}
@@ -457,19 +455,15 @@ def chimeric_adata_core(
457455
dtype=object,
458456
)
459457
if not filtered_df.empty:
460-
for read_id, read_df in filtered_df.groupby(
461-
"read_id", sort=False
462-
):
458+
for read_id, read_df in filtered_df.groupby("read_id", sort=False):
463459
read_index = int(read_id)
464460
if read_index < 0 or read_index >= subset.n_obs:
465461
continue
466462
tuples = _build_top_segments_obs_tuples(
467463
read_df,
468464
subset.obs_names,
469465
)
470-
per_read_obs_series.at[
471-
subset.obs_names[read_index]
472-
] = tuples
466+
per_read_obs_series.at[subset.obs_names[read_index]] = tuples
473467
adata.obs[per_read_obs_key] = per_read_obs_series
474468
_build_zero_hamming_span_layer_from_obs(
475469
adata=adata,
@@ -536,10 +530,7 @@ def chimeric_adata_core(
536530
parent_obsm_key
537531
)
538532
out_png = rolling_nn_dir / f"{safe_sample}__{safe_ref}.png"
539-
title = (
540-
f"{sample} {reference} (n={subset.n_obs})"
541-
f" | window={cfg.rolling_nn_window}"
542-
)
533+
title = f"{sample} {reference} (n={subset.n_obs}) | window={cfg.rolling_nn_window}"
543534
try:
544535
plot_rolling_nn_and_layer(
545536
subset,
@@ -748,8 +739,7 @@ def chimeric_adata_core(
748739

749740
out_png = rolling_nn_layers_dir / f"{safe_sample}__{safe_ref}.png"
750741
title = (
751-
f"{sample} {reference} (n={subset.n_obs})"
752-
f" | window={cfg.rolling_nn_window}"
742+
f"{sample} {reference} (n={subset.n_obs}) | window={cfg.rolling_nn_window}"
753743
)
754744
try:
755745
plot_rolling_nn_and_two_layers(
@@ -789,11 +779,7 @@ def chimeric_adata_core(
789779
.astype("category")
790780
.cat.categories.tolist()
791781
)
792-
references = (
793-
adata.obs[cfg.reference_column]
794-
.astype("category")
795-
.cat.categories.tolist()
796-
)
782+
references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
797783
rng = np.random.RandomState(getattr(cfg, "cross_sample_random_seed", 42))
798784

799785
for reference in references:
@@ -814,22 +800,15 @@ def chimeric_adata_core(
814800
site_mask = mod_site_mask & adata.var[position_col].fillna(False)
815801

816802
for sample in samples:
817-
sample_mask = (
818-
(adata.obs[cfg.sample_name_col_for_plotting] == sample) & ref_mask
819-
)
803+
sample_mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & ref_mask
820804
if not sample_mask.any():
821805
continue
822806

823807
# Build cross-sample pool
824808
grouping_col = getattr(cfg, "cross_sample_grouping_col", None)
825809
if grouping_col and grouping_col in adata.obs.columns:
826-
sample_group_val = (
827-
adata.obs.loc[sample_mask, grouping_col].iloc[0]
828-
)
829-
pool_mask = (
830-
ref_mask
831-
& (adata.obs[grouping_col] == sample_group_val)
832-
)
810+
sample_group_val = adata.obs.loc[sample_mask, grouping_col].iloc[0]
811+
pool_mask = ref_mask & (adata.obs[grouping_col] == sample_group_val)
833812
else:
834813
pool_mask = ref_mask
835814

@@ -856,7 +835,7 @@ def chimeric_adata_core(
856835

857836
# Build sample_labels: 0 = current sample, 1 = other
858837
cross_labels = np.zeros(len(combined_indices), dtype=np.int32)
859-
cross_labels[len(sample_indices):] = 1
838+
cross_labels[len(sample_indices) :] = 1
860839

861840
cross_obsm_key = "cross_sample_rolling_nn_dist"
862841
try:
@@ -995,18 +974,23 @@ def chimeric_adata_core(
995974

996975
# Copy cross-sample obsm into plot_subset
997976
if parent_obsm_key in adata.obsm:
998-
plot_subset.obsm[cfg.rolling_nn_obsm_key] = (
999-
adata[sample_mask].obsm.get(parent_obsm_key)
977+
plot_subset.obsm[cfg.rolling_nn_obsm_key] = adata[sample_mask].obsm.get(
978+
parent_obsm_key
1000979
)
1001980
for suffix in (
1002-
"starts", "centers", "window", "step", "min_overlap",
1003-
"return_fraction", "layer",
981+
"starts",
982+
"centers",
983+
"window",
984+
"step",
985+
"min_overlap",
986+
"return_fraction",
987+
"layer",
1004988
):
1005989
parent_key = f"{parent_obsm_key}_{suffix}"
1006990
if parent_key in adata.uns:
1007-
plot_subset.uns[f"{cfg.rolling_nn_obsm_key}_{suffix}"] = (
1008-
adata.uns[parent_key]
1009-
)
991+
plot_subset.uns[f"{cfg.rolling_nn_obsm_key}_{suffix}"] = adata.uns[
992+
parent_key
993+
]
1010994

1011995
if grouping_col and grouping_col in adata.obs.columns:
1012996
cross_pool_desc = f"cross-sample ({grouping_col}={sample_group_val})"
@@ -1109,11 +1093,7 @@ def chimeric_adata_core(
11091093
.astype("category")
11101094
.cat.categories.tolist()
11111095
)
1112-
references = (
1113-
adata.obs[cfg.reference_column]
1114-
.astype("category")
1115-
.cat.categories.tolist()
1116-
)
1096+
references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
11171097

11181098
# Build delta layer: within - cross, clamped at 0
11191099
if (
@@ -1139,13 +1119,9 @@ def chimeric_adata_core(
11391119
for reference in references:
11401120
ref_mask = adata.obs[cfg.reference_column] == reference
11411121
position_col = f"position_in_{reference}"
1142-
site_cols = [
1143-
f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types
1144-
]
1122+
site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
11451123
missing_cols = [
1146-
col
1147-
for col in [position_col, *site_cols]
1148-
if col not in adata.var.columns
1124+
col for col in [position_col, *site_cols] if col not in adata.var.columns
11491125
]
11501126
if missing_cols:
11511127
continue
@@ -1154,9 +1130,8 @@ def chimeric_adata_core(
11541130

11551131
for sample in samples:
11561132
sample_mask = (
1157-
(adata.obs[cfg.sample_name_col_for_plotting] == sample)
1158-
& ref_mask
1159-
)
1133+
adata.obs[cfg.sample_name_col_for_plotting] == sample
1134+
) & ref_mask
11601135
if not sample_mask.any():
11611136
continue
11621137

@@ -1201,14 +1176,17 @@ def chimeric_adata_core(
12011176
(cross_obsm_key, cross_nn_key),
12021177
):
12031178
for suffix in (
1204-
"starts", "centers", "window", "step",
1205-
"min_overlap", "return_fraction", "layer",
1179+
"starts",
1180+
"centers",
1181+
"window",
1182+
"step",
1183+
"min_overlap",
1184+
"return_fraction",
1185+
"layer",
12061186
):
12071187
src_k = f"{src_obsm}_{suffix}"
12081188
if src_k in adata.uns:
1209-
plot_subset.uns[f"{dst_obsm}_{suffix}"] = (
1210-
adata.uns[src_k]
1211-
)
1189+
plot_subset.uns[f"{dst_obsm}_{suffix}"] = adata.uns[src_k]
12121190

12131191
# Check required span layers
12141192
required_layers = [
@@ -1282,20 +1260,14 @@ def chimeric_adata_core(
12821260
.cat.categories.tolist()
12831261
)
12841262
references = (
1285-
adata.obs[cfg.reference_column]
1286-
.astype("category")
1287-
.cat.categories.tolist()
1263+
adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
12881264
)
12891265
for reference in references:
12901266
ref_mask = adata.obs[cfg.reference_column] == reference
12911267
position_col = f"position_in_{reference}"
1292-
site_cols = [
1293-
f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types
1294-
]
1268+
site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
12951269
missing_cols = [
1296-
col
1297-
for col in [position_col, *site_cols]
1298-
if col not in adata.var.columns
1270+
col for col in [position_col, *site_cols] if col not in adata.var.columns
12991271
]
13001272
if missing_cols:
13011273
continue
@@ -1304,9 +1276,8 @@ def chimeric_adata_core(
13041276

13051277
for sample in samples:
13061278
sample_mask = (
1307-
(adata.obs[cfg.sample_name_col_for_plotting] == sample)
1308-
& ref_mask
1309-
)
1279+
adata.obs[cfg.sample_name_col_for_plotting] == sample
1280+
) & ref_mask
13101281
if not sample_mask.any():
13111282
continue
13121283

@@ -1338,9 +1309,7 @@ def chimeric_adata_core(
13381309
exc,
13391310
)
13401311
else:
1341-
logger.debug(
1342-
"Span length distribution: missing required layers, skipping."
1343-
)
1312+
logger.debug("Span length distribution: missing required layers, skipping.")
13441313

13451314
# ============================================================
13461315
# 4) Save AnnData

src/smftools/cli/latent_adata.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ def _build_shared_valid_non_mod_sites_mask(
7575
raise KeyError(f"var_filters not found in adata.var: {missing}")
7676

7777
mod_masks = [np.asarray(adata.var[col].values, dtype=bool) for col in mod_site_cols]
78-
ref_mod_masks.append(mod_masks[0] if len(mod_masks) == 1 else np.logical_or.reduce(mod_masks))
78+
ref_mod_masks.append(
79+
mod_masks[0] if len(mod_masks) == 1 else np.logical_or.reduce(mod_masks)
80+
)
7981

80-
any_mod_mask = np.logical_or.reduce(ref_mod_masks) if ref_mod_masks else np.zeros(
81-
adata.n_vars, dtype=bool
82+
any_mod_mask = (
83+
np.logical_or.reduce(ref_mod_masks) if ref_mod_masks else np.zeros(adata.n_vars, dtype=bool)
8284
)
8385
return np.logical_and(shared_position_mask, np.logical_not(any_mod_mask))
8486

src/smftools/config/experiment_config.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,16 +1405,12 @@ def from_var_dict(
14051405
),
14061406
rolling_nn_zero_pairs_layer_key=merged.get("rolling_nn_zero_pairs_layer_key", None),
14071407
rolling_nn_zero_pairs_refine=merged.get("rolling_nn_zero_pairs_refine", True),
1408-
rolling_nn_zero_pairs_max_nan_run=merged.get(
1409-
"rolling_nn_zero_pairs_max_nan_run", None
1410-
),
1408+
rolling_nn_zero_pairs_max_nan_run=merged.get("rolling_nn_zero_pairs_max_nan_run", None),
14111409
rolling_nn_zero_pairs_merge_gap=merged.get("rolling_nn_zero_pairs_merge_gap", 0),
14121410
rolling_nn_zero_pairs_max_segments_per_read=merged.get(
14131411
"rolling_nn_zero_pairs_max_segments_per_read", None
14141412
),
1415-
rolling_nn_zero_pairs_max_overlap=merged.get(
1416-
"rolling_nn_zero_pairs_max_overlap", None
1417-
),
1413+
rolling_nn_zero_pairs_max_overlap=merged.get("rolling_nn_zero_pairs_max_overlap", None),
14181414
rolling_nn_zero_pairs_layer_overlap_mode=merged.get(
14191415
"rolling_nn_zero_pairs_layer_overlap_mode", "binary"
14201416
),

src/smftools/plotting/chimeric_plotting.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,9 @@ def _format_labels(values):
445445
def _layer_df_for_key(layer_key: str) -> pd.DataFrame:
446446
layer = subset.layers[layer_key]
447447
layer = layer.toarray() if hasattr(layer, "toarray") else np.asarray(layer)
448-
layer_df = pd.DataFrame(layer[valid], index=subset.obs_names[valid], columns=subset.var_names)
448+
layer_df = pd.DataFrame(
449+
layer[valid], index=subset.obs_names[valid], columns=subset.var_names
450+
)
449451
layer_df.index = layer_df.index.astype(str)
450452
if layer_var_mask is not None:
451453
layer_df = layer_df.loc[:, layer_var_mask]
@@ -1108,20 +1110,35 @@ def _nn_df(obsm_key):
11081110
nn_cmap.set_bad(nn_nan_color)
11091111

11101112
sns.heatmap(
1111-
self_nn_ord, ax=ax_self_nn, cmap=nn_cmap,
1112-
xticklabels=False, yticklabels=False, robust=robust, cbar_ax=ax_self_nn_cbar,
1113+
self_nn_ord,
1114+
ax=ax_self_nn,
1115+
cmap=nn_cmap,
1116+
xticklabels=False,
1117+
yticklabels=False,
1118+
robust=robust,
1119+
cbar_ax=ax_self_nn_cbar,
11131120
)
11141121
sns.heatmap(
1115-
cross_nn_ord, ax=ax_cross_nn, cmap=nn_cmap,
1116-
xticklabels=False, yticklabels=False, robust=robust, cbar_ax=ax_cross_nn_cbar,
1122+
cross_nn_ord,
1123+
ax=ax_cross_nn,
1124+
cmap=nn_cmap,
1125+
xticklabels=False,
1126+
yticklabels=False,
1127+
robust=robust,
1128+
cbar_ax=ax_cross_nn_cbar,
11171129
)
11181130

11191131
layer_cmap = plt.get_cmap("coolwarm").copy()
11201132
if read_span_outside is not None:
11211133
layer_cmap.set_bad(outside_read_color)
11221134
sns.heatmap(
1123-
layer_plot, ax=ax_signal, cmap=layer_cmap,
1124-
xticklabels=False, yticklabels=False, robust=robust, cbar_ax=ax_signal_cbar,
1135+
layer_plot,
1136+
ax=ax_signal,
1137+
cmap=layer_cmap,
1138+
xticklabels=False,
1139+
yticklabels=False,
1140+
robust=robust,
1141+
cbar_ax=ax_signal_cbar,
11251142
)
11261143

11271144
# NN x-tick labels
@@ -1181,16 +1198,34 @@ def _nn_df(obsm_key):
11811198
delta_cmap.set_bad(outside_read_color)
11821199

11831200
sns.heatmap(
1184-
self_span_plot, ax=ax_self_span, cmap=self_span_cmap, norm=self_span_norm,
1185-
xticklabels=False, yticklabels=False, robust=robust, cbar_ax=ax_self_span_cbar,
1201+
self_span_plot,
1202+
ax=ax_self_span,
1203+
cmap=self_span_cmap,
1204+
norm=self_span_norm,
1205+
xticklabels=False,
1206+
yticklabels=False,
1207+
robust=robust,
1208+
cbar_ax=ax_self_span_cbar,
11861209
)
11871210
sns.heatmap(
1188-
cross_span_plot, ax=ax_cross_span, cmap=cross_span_cmap, norm=cross_span_norm,
1189-
xticklabels=False, yticklabels=False, robust=robust, cbar_ax=ax_cross_span_cbar,
1211+
cross_span_plot,
1212+
ax=ax_cross_span,
1213+
cmap=cross_span_cmap,
1214+
norm=cross_span_norm,
1215+
xticklabels=False,
1216+
yticklabels=False,
1217+
robust=robust,
1218+
cbar_ax=ax_cross_span_cbar,
11901219
)
11911220
sns.heatmap(
1192-
delta_span_plot, ax=ax_delta_span, cmap=delta_cmap, norm=delta_norm,
1193-
xticklabels=False, yticklabels=False, robust=robust, cbar_ax=ax_delta_span_cbar,
1221+
delta_span_plot,
1222+
ax=ax_delta_span,
1223+
cmap=delta_cmap,
1224+
norm=delta_norm,
1225+
xticklabels=False,
1226+
yticklabels=False,
1227+
robust=robust,
1228+
cbar_ax=ax_delta_span_cbar,
11941229
)
11951230

11961231
col_labels = [str(x) for x in self_span_ord.columns]

tests/unit/test_chimeric_adata_span_layer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
import anndata as ad
12
import numpy as np
23
import pandas as pd
34

4-
import anndata as ad
5-
65
from smftools.cli.chimeric_adata import _build_zero_hamming_span_layer_from_obs
76

87

0 commit comments

Comments
 (0)