diff --git a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/jetclass_trainer.py b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/jetclass_trainer.py index 1a8b7ae..4f6cd43 100644 --- a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/jetclass_trainer.py +++ b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/jetclass_trainer.py @@ -12,7 +12,10 @@ from .trainer import Trainer from ..utils import cleanup_ddp from ..utils.data import JetClassDistributedSampler -from ..utils.viz import * +from ..utils.viz import ( + plot_roc_curve, + plot_confusion_matrix +) class JetClassTrainer(Trainer): diff --git a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/mm_trainer.py b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/mm_trainer.py index c1d2428..4baa669 100644 --- a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/mm_trainer.py +++ b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/mm_trainer.py @@ -11,7 +11,6 @@ from .trainer import Trainer from ..utils import cleanup_ddp from ..utils.data import JetClassDistributedSampler -from ..utils.viz import * class MaskedModelTrainer(Trainer): diff --git a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/trainer.py b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/trainer.py index 45d317b..136ddf8 100644 --- a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/trainer.py +++ b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/trainer.py @@ -30,7 +30,11 @@ cleanup_ddp ) from ..utils.data import JetClassDistributedSampler -from ..utils.viz import * +from ..utils.viz import ( + plot_roc_curve, + plot_confusion_matrix, + plot_history +) class Trainer: diff --git a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/__init__.py b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/__init__.py index 671f5d7..de6fc7c 100644 --- a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/__init__.py +++ b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/__init__.py @@ -1 +1,17 @@ -from .viz import * \ No newline at end of file +from .viz import ( + plot_feature_distribution, + plot_particle_reconstruction, + plot_history, + plot_ssl_history, + plot_confusion_matrix, + plot_roc_curve +) + +__all__ = [ + 'plot_feature_distribution', + 'plot_particle_reconstruction', + 'plot_history', + 'plot_ssl_history', + 'plot_confusion_matrix', + 'plot_roc_curve' +] \ No newline at end of file diff --git a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/viz.py b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/viz.py index 2af62ab..7338525 100644 --- a/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/viz.py +++ b/MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/viz.py @@ -6,6 +6,16 @@ from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score +__all__ = [ + 'plot_feature_distribution', + 'plot_particle_reconstruction', + 'plot_history', + 'plot_ssl_history', + 'plot_confusion_matrix', + 'plot_roc_curve' +] + + # Function to visualize the distribution of particle features (pT, eta, phi, E) def plot_feature_distribution(X_jets: np.ndarray) -> None: feature_names = ['pT', 'eta', 'phi', 'energy']