Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit b1ec8a9

Browse files
rahul-tulibfineran
authored andcommitted
Bugfix: use label smoothing only when torch version is >= 1.10 (#1352)
* Bugfix: use label smoothing only when torch version is >= 1.10 * Apply suggestions from code review
1 parent 6ef1cb4 commit b1ec8a9

File tree

1 file changed

+10
-1
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+10
-1
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import torch
2929
import torch.utils.data
3030
import torchvision
31+
from packaging import version
3132
from torch import nn
3233
from torch.utils.data.dataloader import DataLoader, default_collate
3334
from torchvision.transforms.functional import InterpolationMode
@@ -408,7 +409,15 @@ def collate_fn(batch):
408409
if args.distributed and args.sync_bn:
409410
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
410411

411-
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
412+
if version.parse(torch.__version__) >= version.parse("1.10"):
413+
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
414+
elif args.label_smoothing > 0:
415+
raise ValueError(
416+
f"`label_smoothing` not supported for {torch.__version__}, "
417+
f"try upgrading to at-least 1.10"
418+
)
419+
else:
420+
criterion = nn.CrossEntropyLoss()
412421

413422
custom_keys_weight_decay = []
414423
if args.bias_weight_decay is not None:

0 commit comments

Comments
 (0)