@@ -29,7 +29,7 @@ class MetricWrapper(nn.Module):
2929 Methods
3030 -------
3131 __call__(y_true, y_pred)
32- Passes the true and predicted labels to the metric functions.
32+ Passes the true and predicted logits to the metric functions.
3333 getmetrics(str_prefix: str = None)
3434 Retrieves the dictionary of computed metrics, optionally all keys can be prefixed with a string.
3535 resetmetric()
@@ -40,10 +40,13 @@ class MetricWrapper(nn.Module):
4040 >>> from CollaborativeCoding import MetricWrapperProposed
4141 >>> metrics = MetricWrapperProposed(2, "entropy", "f1", "precision")
4242 >>> y_true = [0, 1, 0, 1]
43- >>> y_pred = [0, 1, 1, 0]
43+ >>> y_pred = [[0.8, -1.9],
44+ [0.1, 9.0],
45+ [-1.9, -0.1],
46+ [1.9, 1.8]]
4447 >>> metrics(y_true, y_pred)
4548 >>> metrics.getmetrics()
46- {'entropy': 0.6931471805599453 , 'f1': 0.5, 'precision': 0.5}
49+ {'entropy': 0.3292665 , 'f1': 0.5, 'precision': 0.5}
4750 >>> metrics.resetmetric()
4851 >>> metrics.getmetrics()
4952 {'entropy': [], 'f1': [], 'precision': []}
0 commit comments