
Commit 9c0ff05

fix pytorch yolo docstrings
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 2f1755c commit 9c0ff05

File tree

1 file changed, +8 -14 lines changed


art/estimators/object_detection/pytorch_yolo.py

Lines changed: 8 additions & 14 deletions
@@ -153,7 +153,7 @@ def __init__(
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and
     0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
+  - labels [N]: the labels for each image.
   - scores [N]: the scores of each prediction.
 :param input_shape: The shape of one input sample.
 :param optimizer: The optimizer for training the classifier.
@@ -275,8 +275,7 @@ def _preprocess_and_convert_inputs(
 The fields of the Dict are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
-  - scores [N]: the scores of each prediction.
+  - labels [N]: the labels for each image.
 :param fit: `True` if the function is call before fit/training and `False` if the function is called before a
             predict operation.
 :param no_grad: `True` if no gradients required.
@@ -365,8 +364,7 @@ def _get_losses(
 The fields of the Dict are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
-  - scores [N]: the scores of each prediction.
+  - labels [N]: the labels for each image.
 :return: Loss gradients of the same shape as `x`.
 """
 self._model.train()
@@ -401,8 +399,7 @@ def loss_gradient( # pylint: disable=W0613
 The fields of the Dict are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
-  - scores [N]: the scores of each prediction.
+  - labels [N]: the labels for each image.
 :return: Loss gradients of the same shape as `x`.
 """
 import torch
@@ -457,7 +454,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
 are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
+  - labels [N]: the labels for each image.
   - scores [N]: the scores of each prediction.
 """
 import torch
@@ -521,8 +518,7 @@ def fit( # pylint: disable=W0221
 The fields of the Dict are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
-  - scores [N]: the scores of each prediction.
+  - labels [N]: the labels for each image.
 :param batch_size: Size of batches.
 :param nb_epochs: Number of epochs to use for training.
 :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
@@ -604,8 +600,7 @@ def compute_losses(
 The fields of the Dict are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
-  - scores [N]: the scores of each prediction.
+  - labels [N]: the labels for each image.
 :return: Dictionary of loss components.
 """
 loss_components, _ = self._get_losses(x=x, y=y)
@@ -625,8 +620,7 @@ def compute_loss( # type: ignore
 The fields of the Dict are as follows:
 
   - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
-  - labels [N]: the labels for each image
-  - scores [N]: the scores of each prediction.
+  - labels [N]: the labels for each image.
 :return: Loss.
 """
 import torch
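
Every hunk above tidies the same docstring block: the target and prediction dictionaries used by PyTorchYolo carry "boxes" in [x1, y1, x2, y2] format, integer "labels", and (for predictions) "scores". The sketch below illustrates a target in that format; the array values, image size, and the `estimator` name are placeholders for this example, not part of the commit.

import numpy as np

# Target dictionaries in the format the docstrings describe: one Dict per image,
# with "boxes" in [x1, y1, x2, y2] pixel coordinates (0 <= x1 < x2 <= W,
# 0 <= y1 < y2 <= H) and integer class "labels". Values here are made up.
y = [
    {
        "boxes": np.array([[24.0, 30.0, 180.0, 200.0],
                           [210.0, 50.0, 400.0, 380.0]], dtype=np.float32),
        "labels": np.array([1, 16], dtype=np.int64),
    }
]

# A matching batch of images; the channels-first shape and [0, 1] range are
# assumptions for this sketch, not taken from the diff.
x = np.random.rand(1, 3, 416, 416).astype(np.float32)

# With a configured PyTorchYolo estimator (not constructed here), training and
# loss computation would take x and y in this format, for example:
#     estimator.fit(x, y, batch_size=1, nb_epochs=1)
#     losses = estimator.compute_losses(x, y)
# and estimator.predict(x) returns dictionaries of the same shape, with an extra
# "scores" [N] entry giving the confidence of each predicted box.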
