
RuntimeError: Could not infer dtype of numpy.uint8 #164

@EmmaSRH


When I run the training code, neither "uint8" nor "float32" input data is accepted, and training fails with:

❌ Training failed: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 171, in collate
    {
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 172, in <dictcomp>
    key: collate(
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 285, in collate_numpy_array_fn
    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 285, in <listcomp>
    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
RuntimeError: Could not infer dtype of numpy.uint8

Traceback (most recent call last):
  File "/gaba/u/rushi/pytorch_connectomics/scripts/main.py", line 1261, in main
    trainer.fit(
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 560, in fit
    call._call_and_handle_interrupt(
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 49, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 598, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1011, in _run
    results = self._run_stage()
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1055, in _run_stage
    self.fit_loop.run()
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 216, in run
    self.advance()
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 458, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 152, in run
    self.advance(data_fetcher)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 310, in advance
    batch, _, __ = next(data_fetcher)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 134, in __next__
    batch = super().__next__()
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 61, in __next__
    batch = next(self.iterator)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
    out = next(self._iterator)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
    out[i] = next(self.iterators[i])
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1465, in _next_data
    return self._process_data(data)
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1491, in _process_data
    data.reraise()
  File "/u/rushi/conda-envs/pytc/lib/python3.10/site-packages/torch/_utils.py", line 715, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
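The worker traceback ends with torch.as_tensor being called on a NumPy array inside default_collate. One commonly reported cause of this exact message is a PyTorch build that cannot use the installed NumPy, for example a wheel built against NumPy 1.x running with NumPy 2.x. This cause is an assumption and is not confirmed anywhere in this issue. In that situation torch.as_tensor reportedly cannot take its NumPy path and fails on the array's numpy.uint8 elements instead. The minimal sketch below checks for this in the same conda environment. The helper names are illustrative only.

```python
import importlib.metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


def check_numpy_interop():
    """Report NumPy/PyTorch versions and whether torch can wrap a uint8 array.

    Assumption: a NumPy/PyTorch mismatch is the cause being tested for;
    this only checks for it and does not fix anything.
    """
    report = {
        "numpy": installed_version("numpy"),
        "torch": installed_version("torch"),
        "from_numpy_ok": None,
        "error": None,
    }
    try:
        import numpy as np
        import torch
    except ImportError as exc:
        report["error"] = f"import failed: {exc}"
        return report
    try:
        # Same kind of array that default_collate passes to torch.as_tensor.
        torch.from_numpy(np.zeros((2, 2), dtype=np.uint8))
        report["from_numpy_ok"] = True
    except RuntimeError as exc:
        # Raised when this torch build cannot use the installed NumPy.
        report["from_numpy_ok"] = False
        report["error"] = str(exc)
    return report


if __name__ == "__main__":
    print(check_numpy_interop())
```

If from_numpy_ok comes back False, the first thing to try would be aligning the two packages. That could mean installing a NumPy 1.x release, or a PyTorch build that supports the installed NumPy. The traceback alone does not settle which combination fits this environment.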

How can I fix this? This version uses the pytorch_lightning Trainer, which makes it hard to debug.
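The failure happens inside a DataLoader worker, so the Lightning traceback only re-raises it. Setting the loader's num_workers to 0 makes data loading run in the main process, so the error surfaces directly. Alternatively, individual samples can be checked outside the Trainer altogether. The dictcomp frame in the traceback shows that each sample is a dict. A small helper can therefore try the conversion field by field. This helper is hypothetical and is not part of pytorch_connectomics.

```python
def find_unconvertible_fields(sample, convert):
    """Try `convert` on each field of a sample dict and collect failures.

    `convert` is meant to be torch.as_tensor, the call that fails inside
    default_collate's collate_numpy_array_fn. Returns a list of
    (key, dtype, message) tuples, one for every field that raises.
    """
    failures = []
    for key, value in sample.items():
        try:
            convert(value)
        except (RuntimeError, TypeError) as exc:
            # NumPy arrays report their dtype; other values fall back to the type name.
            dtype = getattr(value, "dtype", type(value).__name__)
            failures.append((key, str(dtype), str(exc)))
    return failures
```

You could call it as find_unconvertible_fields(dataset[0], torch.as_tensor). Here `dataset` is a placeholder for whatever training dataset object the script builds. The result lists the offending keys and their dtypes. If every NumPy field fails regardless of dtype, the problem is more likely the environment than the data.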
