@@ -84,6 +84,7 @@ def _multi_transform(self, img: torch.Tensor) -> torch.Tensor:
8484 def _get_metrics (self , x : np .ndarray , batch_size : int = 128 ) -> tuple [dict , np .ndarray ]:
8585 """
8686 Calculate similarities that combining label consistency and representation similarity for given samples
87+
8788 :param x: Input samples
8889 :param batch_size: Batch size for processing
8990 :return: A report similarities
@@ -132,6 +133,7 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.n
132133 def fit (self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 20 , ** kwargs ) -> None :
133134 """
134135 Determine a threshold that covers 95% of clean samples.
136+
135137 :param x: Clean sample data
136138 :param y: Clean sample labels (not used in this method)
137139 :param batch_size: Batch size for processing
@@ -144,6 +146,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
144146 def detect (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> tuple [dict , np .ndarray ]:
145147 """
146148 Detect whether given samples are adversarial
149+
147150 :param x: Input samples
148151 :param batch_size: Batch size for processing
149152 :return: (report, is_adversarial):
0 commit comments