Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from .trainer import Trainer
from ..utils import cleanup_ddp
from ..utils.data import JetClassDistributedSampler
from ..utils.viz import *
from ..utils.viz import (
plot_roc_curve,
plot_confusion_matrix
)


class JetClassTrainer(Trainer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .trainer import Trainer
from ..utils import cleanup_ddp
from ..utils.data import JetClassDistributedSampler
from ..utils.viz import *


class MaskedModelTrainer(Trainer):
Expand Down
6 changes: 5 additions & 1 deletion MAEs/Hybrid_Transformer_Thanh_Nguyen/src/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
cleanup_ddp
)
from ..utils.data import JetClassDistributedSampler
from ..utils.viz import *
from ..utils.viz import (
plot_roc_curve,
plot_confusion_matrix,
plot_history
)


class Trainer:
Expand Down
18 changes: 17 additions & 1 deletion MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
from .viz import *
from .viz import (
plot_feature_distribution,
plot_particle_reconstruction,
plot_history,
plot_ssl_history,
plot_confusion_matrix,
plot_roc_curve
)

__all__ = [
'plot_feature_distribution',
'plot_particle_reconstruction',
'plot_history',
'plot_ssl_history',
'plot_confusion_matrix',
'plot_roc_curve'
]
10 changes: 10 additions & 0 deletions MAEs/Hybrid_Transformer_Thanh_Nguyen/src/utils/viz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score


__all__ = [
'plot_feature_distribution',
'plot_particle_reconstruction',
'plot_history',
'plot_ssl_history',
'plot_confusion_matrix',
'plot_roc_curve'
]


# Function to visualize the distribution of particle features (pT, eta, phi, E)
def plot_feature_distribution(X_jets: np.ndarray) -> None:
feature_names = ['pT', 'eta', 'phi', 'energy']
Expand Down