Skip to content

Commit d22bb50

Browse files
committed
add comments: models
1 parent 2215995 commit d22bb50

15 files changed

+419
-163
lines changed

data/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""This package includes all the modules related to data loading and preprocessing
22
3-
To add a custom dataset class called dummy, you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
3+
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
44
You need to implement four functions:
5-
-- <__init__> (initialize the class, first call BaseDataset.__init__(self, opt))
6-
-- <__len__> (return the size of dataset)
7-
-- <__getitem__> (get a data point)
8-
-- (optionally) <modify_commandline_options> (add dataset-specific options and set default options).
5+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6+
-- <__len__>: return the size of dataset.
7+
-- <__getitem__>: get a data point from data loader.
8+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9+
910
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
1011
See our template dataset class 'template_dataset.py' for more details.
1112
"""
@@ -15,7 +16,7 @@
1516

1617

1718
def find_dataset_using_name(dataset_name):
18-
"""Import the module "data/[datasetname]_dataset.py" given the option '--dataset_mode [datasetname].
19+
"""Import the module "data/[dataset_name]_dataset.py".
1920
2021
In the file, the class called DatasetNameDataset() will
2122
be instantiated. It has to be a subclass of BaseDataset,
@@ -44,14 +45,14 @@ def get_option_setter(dataset_name):
4445

4546

4647
def create_dataset(opt):
47-
"""Create dataset given the option.
48+
"""Create a dataset given the option.
4849
49-
This function warps the class CustomDatasetDataLoader.
50-
This is the main interface called by train.py and test.py.
50+
This function wraps the class CustomDatasetDataLoader.
51+
This is the main interface between this package and 'train.py'/'test.py'
5152
5253
Example:
53-
from data import create_dataset
54-
dataset = create_dataset(opt)
54+
>>> from data import create_dataset
55+
>>> dataset = create_dataset(opt)
5556
"""
5657
data_loader = CustomDatasetDataLoader(opt)
5758
dataset = data_loader.load_data()

data/base_dataset.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@
1111
class BaseDataset(data.Dataset, ABC):
1212
"""This class is an abstract base class (ABC) for datasets.
1313
14-
To create a subclass, you need to implement four functions:
15-
-- <__init__> (initialize the class, first call BaseDataset.__init__(self, opt))
16-
-- <__len__> (return the size of dataset)
17-
-- <__getitem__> (get a data point)
18-
-- (optionally) <modify_commandline_options> (add dataset-specific options and set default options).
14+
To create a subclass, you need to implement the following four functions:
15+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
16+
-- <__len__>: return the size of dataset.
17+
-- <__getitem__>: get a data point.
18+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
1919
"""
2020

2121
def __init__(self, opt):
2222
"""Initialize the class; save the options in the class
2323
2424
Parameters:
25-
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
25+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
2626
"""
2727
self.opt = opt
2828
self.root = opt.dataroot
@@ -32,8 +32,8 @@ def modify_commandline_options(parser, is_train):
3232
"""Add new dataset-specific options, and rewrite default values for existing options.
3333
3434
Parameters:
35-
parser -- original option parser
36-
is_train -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
35+
parser -- original option parser
36+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
3737
3838
Returns:
3939
the modified parser.

data/colorization_dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ class ColorizationDataset(BaseDataset):
1616
def modify_commandline_options(parser, is_train):
1717
"""Add new dataset-specific options, and rewrite default values for existing options.
1818
19+
Parameters:
20+
parser -- original option parser
21+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
22+
23+
Returns:
24+
the modified parser.
25+
1926
By default, the number of channels for input image is 1 (L) and
20-
the nubmer of channels for output image is 2 (ab). The direction is from A to B
27+
the nubmer of channels for output image is 2 (ab). The direction is from A to B
2128
"""
2229
parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
2330
return parser

data/single_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, opt):
1313
"""Initialize this dataset class.
1414
1515
Parameters:
16-
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
16+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
1717
"""
1818
BaseDataset.__init__(self, opt)
1919
self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))

data/template_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def modify_commandline_options(parser, is_train):
2323
"""Add new dataset-specific options, and rewrite default values for existing options.
2424
2525
Parameters:
26-
parser -- original option parser
27-
is_train -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
26+
parser -- original option parser
27+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
2828
2929
Returns:
3030
the modified parser.
@@ -37,7 +37,7 @@ def __init__(self, opt):
3737
"""Initialize this dataset class.
3838
3939
Parameters:
40-
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
40+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
4141
4242
A few things can be done here.
4343
- save the options (have been done in BaseDataset)

data/unaligned_dataset.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, opt):
2020
"""Initialize this dataset class.
2121
2222
Parameters:
23-
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
23+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
2424
"""
2525
BaseDataset.__init__(self, opt)
2626
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
@@ -40,13 +40,13 @@ def __getitem__(self, index):
4040
"""Return a data point and its metadata information.
4141
4242
Parameters:
43-
index - - a random integer for data indexing
43+
index (int) -- a random integer for data indexing
4444
4545
Returns a dictionary that contains A, B, A_paths and B_paths
46-
A(tensor) - - an image in the input domain
47-
B(tensor) - - its corresponding image in the target domain
48-
A_paths(str) - - image paths
49-
B_paths(str) - - image paths
46+
A (tensor) -- an image in the input domain
47+
B (tensor) -- its corresponding image in the target domain
48+
A_paths (str) -- image paths
49+
B_paths (str) -- image paths
5050
"""
5151
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
5252
if self.opt.serial_batches: # make sure index is within then range

docs/overview.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ To help users better understand and use our codebase, we briefly overview the fu
2424
* [base_model.py](../models/base_model.py) implements an abstract base class ([ABC](https://docs.python.org/3/library/abc.html)) for models. It also includes commonly used helper functions (e.g., `setup`, `test`, `update_learning_rate`, `save_networks`, `load_networks`), which can be later used in subclasses.
2525
* [template_model.py](../models/template_model.py) provides a model template with detailed documentation. Check out this file if you plan to implement your own model.
2626
* [pix2pix_model.py](../models/pix2pix_model.py) implements the pix2pix [model](https://phillipi.github.io/pix2pix/), for learning a mapping from input images to output images given paired data. The model training requires `--dataset_mode aligned` dataset. By default, it uses a `--netG unet256` [U-Net](https://arxiv.org/pdf/1505.04597.pdf) generator, a `--netD basic` discriminator (PatchGAN), and a `--gan_mode vanilla` GAN loss (standard cross-entropy objective).
27-
* [colorization_model.py](../models/colorization_model.py) implements a subclass of `Pix2PixModel` for image colorization (black & white image to colorful image). The model training requires `-dataset_model colorization` dataset. It trains a pix2pix model, mapping from L channel to ab channel in [Lab](https://en.wikipedia.org/wiki/CIELAB_color_space) color space. By default, the model will automatically set `--input_nc 1` and `--output_nc 2`.
28-
* [cycle_gan_model.py](../models/cycle_gan_model.py) implements the CycleGAN [model](https://junyanz.github.io/CycleGAN/), for learning image-to-image translation without paired data. The model training requires `--dataset_mode unaligned` dataset. By default, it uses a `--netG resnet_9blocks` ResNet generator, a `--netD basic` discrimiator (PatchGAN introduced by pix2pix), and a least-square GANs [objective](https://arxiv.org/abs/1611.04076) (`--gan_mode lsgan`).
27+
* [colorization_model.py](../models/colorization_model.py) implements a subclass of `Pix2PixModel` for image colorization (black & white image to colorful image). The model training requires `-dataset_model colorization` dataset. It trains a pix2pix model, mapping from L channel to ab channels in [Lab](https://en.wikipedia.org/wiki/CIELAB_color_space) color space. By default, the `colorization` dataset will automatically set `--input_nc 1` and `--output_nc 2`.
28+
* [cycle_gan_model.py](../models/cycle_gan_model.py) implements the CycleGAN [model](https://junyanz.github.io/CycleGAN/), for learning image-to-image translation without paired data. The model training requires `--dataset_mode unaligned` dataset. By default, it uses a `--netG resnet_9blocks` ResNet generator, a `--netD basic` discriminator (PatchGAN introduced by pix2pix), and a least-square GANs [objective](https://arxiv.org/abs/1611.04076) (`--gan_mode lsgan`).
2929
* [networks.py](../models/networks.py) module implements network architectures (both generators and discriminators), as well as normalization layers, initialization methods, optimization scheduler (i.e., learning rate policy), and GAN objective function (`vanilla`, `lsgan`, `wgangp`).
30-
* [test_model.py](../models/test_model.py) implements a model that can be used to generate CycleGAN results for only one direction. This option will automatically set `--dataset_mode single`, which only loads the images from one set. See the test [instruction](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#apply-a-pre-trained-model-cyclegan) for more details.
30+
* [test_model.py](../models/test_model.py) implements a model that can be used to generate CycleGAN results for only one direction. This model will automatically set `--dataset_mode single`, which only loads the images from one set. See the test [instruction](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#apply-a-pre-trained-model-cyclegan) for more details.
3131

3232
[options](../options) directory includes our option modules: training options, test options, and basic options (used in both training and test). `TrainOptions` and `TestOptions` are both subclasses of `BaseOptions`. They will reuse the options defined in `BaseOptions`.
3333
* [\_\_init\_\_.py](../options/__init__.py) is required to make Python treat the directory `options` as containing packages,

models/__init__.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,36 @@
1+
"""This package contains modules related to objective functions, optimizations, and network architectures.
2+
3+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4+
You need to implement the following five functions:
5+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6+
-- <set_input>: unpack data from dataset and apply preprocessing.
7+
-- <forward>: produce intermediate results.
8+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
9+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10+
11+
In the function <__init__>, you need to define four lists:
12+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
13+
-- self.model_names (str list): specify the images that you want to display and save.
14+
-- self.visual_names (str list): define networks used in our training.
15+
-- self.optimizers (optimzier list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16+
17+
Now you can use the model class by specifying flag '--model dummy'.
18+
See our template model class 'template_model.py' for an example.
19+
"""
20+
121
import importlib
222
from models.base_model import BaseModel
323

424

525
def find_model_using_name(model_name):
6-
# Given the option --model [modelname],
7-
# the file "models/modelname_model.py"
8-
# will be imported.
26+
"""Import the module "models/[model_name]_model.py".
27+
28+
In the file, the class called DatasetNameModel() will
29+
be instantiated. It has to be a subclass of BaseModel,
30+
and it is case-insensitive.
31+
"""
932
model_filename = "models." + model_name + "_model"
1033
modellib = importlib.import_module(model_filename)
11-
12-
# In the file, the class called ModelNameModel() will
13-
# be instantiated. It has to be a subclass of BaseModel,
14-
# and it is case-insensitive.
1534
model = None
1635
target_model_name = model_name.replace('_', '') + 'model'
1736
for name, cls in modellib.__dict__.items():
@@ -27,11 +46,21 @@ def find_model_using_name(model_name):
2746

2847

2948
def get_option_setter(model_name):
49+
"""Return the static method <modify_commandline_options> of the model class."""
3050
model_class = find_model_using_name(model_name)
3151
return model_class.modify_commandline_options
3252

3353

3454
def create_model(opt):
55+
"""Create a model given the option.
56+
57+
This function warps the class CustomDatasetDataLoader.
58+
This is the main interface between this package and 'train.py'/'test.py'
59+
60+
Example:
61+
>>> from models import create_model
62+
>>> model = create_model(opt)
63+
"""
3564
model = find_model_using_name(opt.model)
3665
instance = model(opt)
3766
print("model [%s] was created" % type(instance).__name__)

0 commit comments

Comments
 (0)