Description
Bug description
Hi there! I have created my first LightningDataModule, more specifically a NonGeoDataModule, which inherits from it (see torchgeo-fork). When I try to run this module I get RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor. Even more interesting, if I override transfer_batch_to_device like this:
    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
        print("----------------------------------------")
        # get_device() returns the CUDA device index, or -1 for a CPU tensor
        for k in batch.keys():
            print(k, batch[k][0].get_device())
        print("----------------------------------------")
        return batch

I get the output
image 0
mask 0
The error happens during the validation step (lightning/pytorch/strategies/strategy.py, line 411).
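Since the batch is reported on device 0 right after transfer_batch_to_device but the convolution later receives a CPU tensor, it may help to print the devices again further down the pipeline. Lightning calls on_after_batch_transfer after transfer_batch_to_device, and validation_step after that, so either is a candidate place to look. The helper below is only a minimal debugging sketch: describe_devices is a hypothetical name, not part of Lightning or torchgeo, and it assumes the batch is a nested dict/list/tuple whose leaves expose a .device attribute (as torch.Tensor does).

```python
from typing import Any, Dict


def describe_devices(batch: Any, prefix: str = "batch") -> Dict[str, str]:
    """Map each tensor-like leaf of a nested batch to the device it lives on.

    Hypothetical helper for debugging. Any object with a ``.device``
    attribute (for example ``torch.Tensor``) is treated as a leaf; dicts,
    lists and tuples are walked recursively; everything else is ignored.
    """
    found: Dict[str, str] = {}
    if hasattr(batch, "device"):
        found[prefix] = str(batch.device)
    elif isinstance(batch, dict):
        for key, value in batch.items():
            found.update(describe_devices(value, f"{prefix}[{key!r}]"))
    elif isinstance(batch, (list, tuple)):
        for index, value in enumerate(batch):
            found.update(describe_devices(value, f"{prefix}[{index}]"))
    return found
```

Calling print(describe_devices(batch)) at the top of validation_step, and at the end of an overridden on_after_batch_transfer, would show whether a tensor has moved back to the CPU somewhere between the transfer and the forward pass.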
What version are you seeing the problem on?
v2.4
How to reproduce the bug
    def train(
        config: dict,
        data_dir: str = default_data_dir,
        root_dir: str = default_root_dir,
        min_epochs: int = 1,
        max_epochs: int = 25,
    ) -> None:
        tune_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
        module = FL(
            num_workers=config["num_workers"],
            batch_size=config["batch_size"],
            patch_size=config["patch_size"],
            val_split_pct=0.25,
            use_toy=True,
            # augs=transforms,
            root=data_dir,
        )
        task = SemanticSegmentationTask(
            model="unet",
            backbone="resnet50",
            ignore_index=255,
            in_channels=5,  # (5+3), appended indices
            num_classes=13,
            lr=config["lr"],
            patience=config["lr_patience"],
        )
        # Callbacks
        checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")
        lr_monitor = LearningRateMonitor(logging_interval="step")
        tune_callback = TuneReportCheckpointCallback(
            {"loss": "val_loss", "accuracy": "val_accuracy"}, on="validation_end"
        )
        logger = TensorBoardLogger(save_dir=root_dir, name="FLAIR2logs")
        trainer = Trainer(
            accelerator=accelerator,
            num_nodes=1,
            callbacks=[checkpoint_callback, lr_monitor, tune_callback],
            log_every_n_steps=1,
            logger=logger,
            min_epochs=1,
            max_epochs=25,
            precision=32,
        )
        trainer.fit(model=task, datamodule=module)

Error messages and logs
Traceback (most recent call last):
File "//Dev/forks/torchgeo/train_simple.py", line 158, in <module>
main()
File "//Dev/forks/torchgeo/train_simple.py", line 154, in main
train(config)
File "//Dev/forks/torchgeo/train_simple.py", line 151, in train
trainer.fit(model=task, datamodule=module)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
results = self._run_stage()
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
self._run_sanity_check()
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
val_loop.run()
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
return loop_run(self, *args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
output = fn(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File "//Dev/forks/torchgeo/torchgeo/trainers/segmentation.py", line 251, in validation_step
y_hat = self(x)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "//Dev/forks/torchgeo/torchgeo/trainers/base.py", line 81, in forward
return self.model(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/base/model.py", line 38, in forward
features = self.encoder(x)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/encoders/resnet.py", line 63, in forward
x = stages[i](x)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
input = module(input)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
return self._conv_forward(input, self.weight, self.bias)
File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
Environment
Current environment
-----------------------------------------------------------
Python Version: 3.10.4
PyTorch Version: 2.4.1
Cuda is available version: 12.4
Torch built with CUDA: True
cuDNN Version: 90100
cuDNN Enabled: True
cuDNN available: True
Device: cuda
Accelerator: gpu
lightning 2.4.0
lightning-utilities 0.11.9
pytorch-lightning 2.4.0
## conda env
name: torchgeo
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.10
- pytorch-cuda=12.4
- pytorch=2.4
- torchgeo=0.6.0
- tensorboard=2.17
-----------------------------------------------------------
More info
No response