|
13 | 13 | PN2VAlgorithm, |
14 | 14 | ) |
15 | 15 | from careamics.config.architectures import LVAEConfig, UNetConfig |
16 | | -from careamics.config.data import DataConfig, NGDataConfig |
| 16 | +from careamics.config.data import DataConfig |
17 | 17 | from careamics.config.lightning.training_config import TrainingConfig |
18 | 18 | from careamics.config.losses.loss_config import LVAELossConfig |
19 | 19 | from careamics.config.noise_model.likelihood_config import ( |
@@ -357,99 +357,6 @@ def _create_microsplit_data_configuration( |
357 | 357 | return MicroSplitDataConfig(**data) |
358 | 358 |
|
359 | 359 |
|
def create_ng_data_configuration(
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
    channels: Sequence[int] | None = None,
    in_memory: bool | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    pred_dataloader_params: dict[str, Any] | None = None,
    seed: int | None = None,
) -> NGDataConfig:
    """
    Create a training NGDataConfig.

    Parameters
    ----------
    data_type : {"array", "tiff", "zarr", "czi", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : Sequence of int
        Size of the patches along the spatial dimensions.
    batch_size : int
        Batch size.
    augmentations : list of transforms or None, default=None
        List of transforms to apply. If `None`, default augmentations are applied
        (flip in X and Y, rotations by 90 degrees in the XY plane).
    channels : Sequence of int, default=None
        List of channels to use. If `None`, all channels are used.
    in_memory : bool, default=None
        Whether to load all data into memory. This is only supported for 'array',
        'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
        'tiff' and `custom`, and `False` for 'zarr' and 'czi' data types. Must be
        `True` for `array`.
    train_dataloader_params : dict, default=None
        Parameters for the training dataloader, see PyTorch notes.
    val_dataloader_params : dict, default=None
        Parameters for the validation dataloader, see PyTorch notes.
    pred_dataloader_params : dict, default=None
        Parameters for the test dataloader, see PyTorch notes.
    seed : int, default=None
        Random seed for reproducibility. If `None`, no seed is set.

    Returns
    -------
    NGDataConfig
        Next-Generation Data model with the specified parameters.
    """
    if augmentations is None:
        augmentations = _list_spatial_augmentations()

    # data model
    data: dict[str, Any] = {
        "mode": "training",
        "data_type": data_type,
        "axes": axes,
        "batch_size": batch_size,
        "channels": channels,
        "transforms": augmentations,
        "seed": seed,
    }

    if in_memory is not None:
        data["in_memory"] = in_memory

    # don't override defaults set in DataConfig class
    if train_dataloader_params is not None:
        # the presence of `shuffle` key in the dataloader parameters is enforced
        # by the NGDataConfig class; build an updated copy instead of mutating
        # the caller's dict in place
        if "shuffle" not in train_dataloader_params:
            train_dataloader_params = {**train_dataloader_params, "shuffle": True}

        data["train_dataloader_params"] = train_dataloader_params

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    if pred_dataloader_params is not None:
        data["pred_dataloader_params"] = pred_dataloader_params

    # add training patching
    data["patching"] = {
        "name": "random",
        "patch_size": patch_size,
    }

    return NGDataConfig(**data)
451 | | - |
452 | | - |
453 | 360 | def _create_training_configuration( |
454 | 361 | trainer_params: dict, |
455 | 362 | logger: Literal["wandb", "tensorboard", "none"], |
|
0 commit comments