Skip to content

Commit fe6a588

Browse files
author
Alexander März
committed
Changed dist_select plotting
1 parent bfd869b commit fe6a588

File tree

2 files changed

+27
-53
lines changed

2 files changed

+27
-53
lines changed

lightgbmlss/distributions/distribution_utils.py

Lines changed: 13 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99
from tqdm import tqdm
1010

1111
from typing import Any, Dict, Optional, List, Tuple
12-
from plotnine import *
12+
import matplotlib.pyplot as plt
13+
import seaborn as sns
1314
import warnings
1415

1516

@@ -598,7 +599,6 @@ def dist_select(self,
598599
target: np.ndarray,
599600
candidate_distributions: List,
600601
max_iter: int = 100,
601-
n_samples: int = 1000,
602602
plot: bool = False,
603603
figure_size: tuple = (10, 5),
604604
) -> pd.DataFrame:
@@ -614,8 +614,6 @@ def dist_select(self,
614614
List of candidate distributions.
615615
max_iter: int
616616
Maximum number of iterations for the optimization.
617-
n_samples: int
618-
Number of samples to draw from the fitted distribution.
619617
plot: bool
620618
If True, a density plot of the actual and fitted distribution is created.
621619
figure_size: tuple
@@ -650,11 +648,11 @@ def dist_select(self,
650648
}
651649
)
652650
dist_list.append(fit_df)
653-
fit_df = pd.concat(dist_list).sort_values(by=self.loss_fn, ascending=True)
654-
fit_df["rank"] = fit_df[self.loss_fn].rank().astype(int)
655-
fit_df.set_index(fit_df["rank"], inplace=True)
656651
pbar.update(1)
657652
pbar.set_description(f"Fitting of candidate distributions completed")
653+
fit_df = pd.concat(dist_list).sort_values(by=self.loss_fn, ascending=True)
654+
fit_df["rank"] = fit_df[self.loss_fn].rank().astype(int)
655+
fit_df.set_index(fit_df["rank"], inplace=True)
658656

659657
if plot:
660658
# Select best distribution
@@ -675,29 +673,19 @@ def dist_select(self,
675673
axis=1,
676674
)
677675
fitted_params = pd.DataFrame(fitted_params, columns=best_dist_sel.param_dict.keys())
678-
fitted_params.columns = best_dist_sel.param_dict.keys()
676+
n_samples = np.max([10000, target.shape[0]])
677+
n_samples = np.where(n_samples > 500000, 100000, n_samples)
679678
dist_samples = best_dist_sel.draw_samples(fitted_params,
680679
n_samples=n_samples,
681680
seed=123).values
682681

683682
# Plot actual and fitted distribution
684-
plot_df_actual = pd.DataFrame({"y": target.reshape(-1,), "type": "Actual"})
685-
plot_df_fitted = pd.DataFrame({"y": dist_samples.reshape(-1,),
686-
"type": f"Best-Fit: {best_dist['distribution'].values[0]}"})
687-
plot_df = pd.concat([plot_df_actual, plot_df_fitted])
688-
689-
print(
690-
ggplot(plot_df,
691-
aes(x="y",
692-
color="type")) +
693-
geom_density(alpha=0.5) +
694-
theme_bw(base_size=15) +
695-
theme(figure_size=figure_size,
696-
legend_position="right",
697-
legend_title=element_blank(),
698-
plot_title=element_text(hjust=0.5)) +
699-
labs(title=f"Actual vs. Fitted Density")
700-
)
683+
plt.figure(figsize=figure_size)
684+
sns.kdeplot(target.reshape(-1, ), label="Actual")
685+
sns.kdeplot(dist_samples.reshape(-1, ), label=f"Best-Fit: {best_dist['distribution'].values[0]}")
686+
plt.legend()
687+
plt.title("Actual vs. Best-Fit Density", fontweight="bold", fontsize=16)
688+
plt.show()
701689

702690
fit_df.drop(columns=["rank", "params"], inplace=True)
703691

lightgbmlss/distributions/flow_utils.py

Lines changed: 14 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@
1111
from tqdm import tqdm
1212

1313
from typing import Any, Dict, Optional, List, Tuple
14-
from plotnine import *
14+
import matplotlib.pyplot as plt
15+
import seaborn as sns
1516
import warnings
1617

1718

@@ -637,7 +638,6 @@ def flow_select(self,
637638
target: np.ndarray,
638639
candidate_flows: List,
639640
max_iter: int = 100,
640-
n_samples: int = 1000,
641641
plot: bool = False,
642642
figure_size: tuple = (10, 5),
643643
) -> pd.DataFrame:
@@ -653,8 +653,6 @@ def flow_select(self,
653653
List of candidate normalizing flow specifications.
654654
max_iter: int
655655
Maximum number of iterations for the optimization.
656-
n_samples: int
657-
Number of samples drawn from the fitted distribution.
658656
plot: bool
659657
If True, a density plot of the actual and fitted distribution is created.
660658
figure_size: tuple
@@ -692,11 +690,11 @@ def flow_select(self,
692690
}
693691
)
694692
flow_list.append(fit_df)
695-
fit_df = pd.concat(flow_list).sort_values(by=flow_sel.loss_fn, ascending=True)
696-
fit_df["rank"] = fit_df[flow_sel.loss_fn].rank().astype(int)
697-
fit_df.set_index(fit_df["rank"], inplace=True)
698693
pbar.update(1)
699694
pbar.set_description(f"Fitting of candidate normalizing flows completed")
695+
fit_df = pd.concat(flow_list).sort_values(by=flow_sel.loss_fn, ascending=True)
696+
fit_df["rank"] = fit_df[flow_sel.loss_fn].rank().astype(int)
697+
fit_df.set_index(fit_df["rank"], inplace=True)
700698

701699
if plot:
702700
# Select normalizing flow with the lowest loss
@@ -713,29 +711,17 @@ def flow_select(self,
713711
flow_params = torch.tensor(best_flow["params"][0]).reshape(1, -1)
714712
flow_dist_sel = best_flow_sel.create_spline_flow(input_dim=1)
715713
_, flow_dist_sel = best_flow_sel.replace_parameters(flow_params, flow_dist_sel)
716-
flow_samples = pd.DataFrame(flow_dist_sel.sample((n_samples,)).squeeze().detach().numpy().T)
714+
n_samples = np.max([10000, target.shape[0]])
715+
n_samples = np.where(n_samples > 500000, 100000, n_samples)
716+
flow_samples = pd.DataFrame(flow_dist_sel.sample((n_samples,)).squeeze().detach().numpy().T).values
717717

718718
# Plot actual and fitted distribution
719-
flow_samples["type"] = f"Best-Fit: {best_flow['NormFlow'].values[0]}"
720-
721-
df_actual = pd.DataFrame(target)
722-
df_actual["type"] = "Data"
723-
724-
plot_df = pd.concat([df_actual, flow_samples]).rename(columns={0: "variable"})
725-
726-
print(
727-
ggplot(plot_df,
728-
aes(x="variable",
729-
color="type")) +
730-
geom_density(size=1.1) +
731-
theme_bw(base_size=15) +
732-
theme(figure_size=figure_size,
733-
legend_position="right",
734-
legend_title=element_blank(),
735-
plot_title=element_text(hjust=0.5)) +
736-
labs(title=f"Actual vs. Fitted Density",
737-
x="")
738-
)
719+
plt.figure(figsize=figure_size)
720+
sns.kdeplot(target.reshape(-1, ), label="Actual")
721+
sns.kdeplot(flow_samples.reshape(-1, ), label=f"Best-Fit: {best_flow['NormFlow'].values[0]}")
722+
plt.legend()
723+
plt.title("Actual vs. Best-Fit Density", fontweight="bold", fontsize=16)
724+
plt.show()
739725

740726
fit_df.drop(columns=["rank", "params"], inplace=True)
741727

0 commit comments

Comments
 (0)