Skip to content

Commit c0d52bc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4c79f3b commit c0d52bc

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tiatoolbox/models/architecture/vanilla.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _get_timm_architecture(
117117
Example:
118118
>>> model = _get_timm_architecture("UNI", pretrained=True)
119119
>>> print(model)
120-
120+
121121
"""
122122
if arch_name in [f"efficientnet_b{i}" for i in range(8)]:
123123
model = timm.create_model(arch_name, pretrained=pretrained)
@@ -298,7 +298,7 @@ def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
298298
Returns:
299299
torch.Tensor:
300300
The output logits after passing through the model.
301-
301+
302302
"""
303303
feat = self.feat_extract(imgs)
304304
gap_feat = self.pool(feat)
@@ -319,7 +319,7 @@ def postproc(image: np.ndarray) -> np.ndarray:
319319
Returns:
320320
np.ndarray:
321321
The post-processed image array.
322-
322+
323323
"""
324324
return _postproc(image=image)
325325

@@ -345,7 +345,7 @@ def infer_batch(
345345
Example:
346346
>>> output = _infer_batch(model, batch_data, "cuda")
347347
>>> print(output)
348-
348+
349349
"""
350350
return _infer_batch(model=model, batch_data=batch_data, device=device)
351351

@@ -386,7 +386,7 @@ class TimmModel(ModelABC):
386386
>>> model = TimmModel("UNI", pretrained=True)
387387
>>> output = model(torch.randn(1, 3, 224, 224))
388388
>>> print(output.shape)
389-
389+
390390
"""
391391

392392
def __init__(
@@ -540,7 +540,7 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
540540
Returns:
541541
torch.Tensor:
542542
The extracted features.
543-
543+
544544
"""
545545
feat = self.feat_extract(imgs)
546546
gap_feat = self.pool(feat)
@@ -572,7 +572,7 @@ def infer_batch(
572572
Example:
573573
>>> output = CNNBackbone.infer_batch(model, batch_data, "cuda")
574574
>>> print(output)
575-
575+
576576
"""
577577
return [_infer_batch(model=model, batch_data=batch_data, device=device)]
578578

@@ -660,6 +660,6 @@ def infer_batch(
660660
Example:
661661
>>> output = TimmBackbone.infer_batch(model, batch_data, "cuda")
662662
>>> print(output)
663-
663+
664664
"""
665665
return [_infer_batch(model=model, batch_data=batch_data, device=device)]

0 commit comments

Comments
 (0)