Skip to content

Commit ac94134

Browse files
committed
fix: fix train dataset tensor conversion
1 parent 224d263 commit ac94134

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

cellseg_models_pytorch/torch_datasets/folder_dataset_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.utils.data import Dataset
77

88
from cellseg_models_pytorch.transforms.albu_transforms import ApplyEach
9-
from cellseg_models_pytorch.utils import FileHandler
9+
from cellseg_models_pytorch.utils import FileHandler, to_tensor
1010

1111
try:
1212
import albumentations as A
@@ -118,8 +118,8 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
118118

119119
tr = self.transforms(image=data["image"], masks=[masks])
120120

121-
image = tr["image"].to(self.output_device)
122-
masks = tr["masks"][0].to(self.output_device)
121+
image = to_tensor(tr["image"])
122+
masks = to_tensor(tr["masks"][0])
123123
masks = torch.split(masks, mask_chls, dim=0)
124124

125125
integer_masks = {k: masks[i] for i, k in enumerate(self.mask_keys)}

cellseg_models_pytorch/torch_datasets/hdf5_dataset_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.utils.data import Dataset
66

77
from cellseg_models_pytorch.transforms.albu_transforms import ApplyEach
8-
from cellseg_models_pytorch.utils import FileHandler
8+
from cellseg_models_pytorch.utils import FileHandler, to_tensor
99

1010
try:
1111
import albumentations as A
@@ -125,8 +125,8 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
125125

126126
tr = self.transforms(image=data["image"], masks=[masks])
127127

128-
image = tr["image"].to(self.output_device)
129-
masks = tr["masks"][0].to(self.output_device)
128+
image = to_tensor(tr["image"])
129+
masks = to_tensor(tr["masks"][0])
130130
masks = torch.split(masks, mask_chls, dim=0)
131131

132132
integer_masks = {k: masks[i] for i, k in enumerate(self.mask_keys)}

0 commit comments

Comments
 (0)