utils
antonemanuel edited this page Aug 20, 2021
Provides utility functions for anomaly detection.
to_batch(images: List[np.ndarray], transforms: T.Compose, device: torch.device) -> torch.Tensor
Convert a list of NumPy array images into a PyTorch tensor batch, applying the given transforms.
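The description implies two steps: transform each image, then stack the results along a new batch axis. Below is a minimal NumPy sketch of those steps under stated assumptions, not the library's implementation. to_batch_np and to_chw_float are hypothetical names. to_chw_float stands in for a transform pipeline and assumes HWC uint8 input, roughly mirroring what torchvision's ToTensor does. In PyTorch the equivalent would presumably stack tensors with torch.stack and move the result to device with .to(device).

```python
import numpy as np
from typing import Callable, List


def to_chw_float(image: np.ndarray) -> np.ndarray:
    """Hypothetical transform: HWC uint8 image -> CHW float32 in [0, 1]."""
    return np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0


def to_batch_np(images: List[np.ndarray],
                transform: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Apply `transform` to every image and stack the results into one batch.

    Assumes every transformed image has the same shape, e.g. (C, H, W), so the
    results can be stacked along a new leading batch axis.
    """
    transformed = [transform(image) for image in images]
    return np.stack(transformed, axis=0)


# Two white 4x5 RGB images -> a batch of shape (N, C, H, W).
images = [np.full((4, 5, 3), 255, dtype=np.uint8) for _ in range(2)]
batch = to_batch_np(images, to_chw_float)
print(batch.shape)  # (2, 3, 4, 5)
```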
pytorch_cov(tensor: torch.Tensor, rowvar: bool = True, bias: bool = False) -> torch.Tensor
Estimate a covariance matrix, analogous to np.cov.
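Assuming pytorch_cov follows np.cov's conventions (with rowvar=True each row is a variable and each column an observation; bias=False normalises by N - 1, bias=True by N), the estimate is the centred data multiplied by its own transpose and divided by N - 1 or N. The NumPy sketch below implements that formula for 2-D input only; cov_sketch is a hypothetical name, and the result can be checked against np.cov directly.

```python
import numpy as np


def cov_sketch(x: np.ndarray, rowvar: bool = True, bias: bool = False) -> np.ndarray:
    """Covariance of 2-D data following np.cov's rowvar/bias conventions."""
    if not rowvar:
        x = x.T  # make each row a variable, each column an observation
    n_obs = x.shape[1]
    centered = x - x.mean(axis=1, keepdims=True)  # subtract each variable's mean
    ddof = 0 if bias else 1  # bias=False -> divide by N - 1, bias=True -> by N
    return centered @ centered.T / (n_obs - ddof)


rng = np.random.default_rng(0)
data = rng.normal(size=(3, 50))  # 3 variables, 50 observations
print(np.allclose(cov_sketch(data), np.cov(data)))  # True
```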
mahalanobis(mean: torch.Tensor, cov_inv: torch.Tensor, batch: torch.Tensor) -> torch.Tensor
Calculate the Mahalanobis distance between a multivariate normal distribution and a point, or elementwise between a set of distributions and a set of points.
Arguments:
- mean: A mean vector or a set of mean vectors.
- cov_inv: An inverse covariance matrix or a set of inverse covariance matrices.
- batch: A point or a set of points.
Returns:
- mahalonobis_distance: A distance, a set of distances, or a set of sets of distances.
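For a single distribution with mean mu and inverse covariance S^-1, the distance of a point x is sqrt((x - mu)^T S^-1 (x - mu)). The NumPy sketch below, with the hypothetical name mahalanobis_sketch, evaluates that expression over the last axis and relies on broadcasting for the batched cases. The shapes in its docstring, such as mean (K, D), cov_inv (K, D, D) and batch (B, K, D) for the "set of sets" case, are an assumption about how the library lays out its inputs.

```python
import numpy as np


def mahalanobis_sketch(mean: np.ndarray, cov_inv: np.ndarray, batch: np.ndarray) -> np.ndarray:
    """sqrt((x - mu)^T S^-1 (x - mu)) computed over the last axis.

    Assumed shapes (broadcast): mean (D,), cov_inv (D, D), batch (D,) for one
    distribution and one point; mean (K, D), cov_inv (K, D, D), batch (B, K, D)
    for K distributions evaluated elementwise on B samples.
    """
    delta = batch - mean
    # Row vector delta^T times S^-1; matmul broadcasts the leading axes.
    left = (delta[..., None, :] @ cov_inv)[..., 0, :]
    # Dot the result with delta to finish the quadratic form.
    squared = (left * delta).sum(axis=-1)
    return np.sqrt(squared)


# Identity inverse covariance reduces to the Euclidean distance.
print(mahalanobis_sketch(np.zeros(2), np.eye(2), np.array([3.0, 4.0])))  # 5.0
```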
image_score(patch_scores: torch.Tensor) -> torch.Tensor
Calculate image scores from patch scores.
Arguments:
- patch_scores: A batch of patch scores.
Returns:
- image_scores: A batch of image scores.
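The page does not say how patch scores are reduced to an image score. A common choice in patch-based anomaly detection is to take the maximum patch score of each image; the sketch below assumes that reduction and a leading batch axis, and image_score_sketch is a hypothetical name. The library may use a different reduction.

```python
import numpy as np


def image_score_sketch(patch_scores: np.ndarray) -> np.ndarray:
    """Hypothetical reduction: each image's score is its highest patch score.

    Assumes patch_scores has shape (B, ...): one leading batch axis, with all
    remaining axes (e.g. an H x W grid of patches) reduced by max.
    """
    flat = patch_scores.reshape(patch_scores.shape[0], -1)
    return flat.max(axis=1)


scores = np.array([[[0.1, 0.4], [0.2, 0.3]],
                   [[0.9, 0.0], [0.5, 0.6]]])
print(image_score_sketch(scores))  # [0.4 0.9]
```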
classification(image_scores: torch.Tensor, thresh: float) -> torch.Tensor
Calculate image classifications from image scores.
Arguments:
- image_scores: A batch of image scores.
- thresh: A threshold value. An image whose score is greater than or equal to thresh is classified as anomalous.
Returns:
- image_classifications: A batch of image classifications.
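The thresholding rule above translates directly into an elementwise comparison. In this minimal sketch (classification_sketch is a hypothetical name), anomalous images are encoded as 1 and normal ones as 0; that encoding is an assumption, since the page does not state the output dtype or values.

```python
import numpy as np


def classification_sketch(image_scores: np.ndarray, thresh: float) -> np.ndarray:
    """Label an image 1 (anomalous) if its score >= thresh, else 0 (normal)."""
    return (image_scores >= thresh).astype(np.int64)


print(classification_sketch(np.array([0.2, 0.5, 0.9]), thresh=0.5))  # [0 1 1]
```

A score exactly equal to thresh is labelled anomalous, matching the "greater than or equal to" rule.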