Commit ea2888a

Renaming all instances of train_step to train_batch. Introducing validate|test|infer_batch to all existing models. Updated engine creation code to look for appropriate model methods. (#664)

1 parent 51157ed commit ea2888a

14 files changed, +558 −42 lines
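The commit message says the engine creation code now looks for the appropriate per-batch method on each model. A minimal sketch of that kind of lookup follows; it is illustrative only — `make_step_fn` and `DemoModel` are invented names, not the repo's actual engine code:

```python
def make_step_fn(model, funcname):
    # Look up the per-batch method (e.g. "train_batch", "validate_batch",
    # "test_batch", or "infer_batch") on the model instance, failing loudly
    # if the model does not define it.
    step_fn = getattr(model, funcname, None)
    if not callable(step_fn):
        raise RuntimeError(
            f"Model {type(model).__name__} does not define {funcname}()"
        )
    return step_fn


class DemoModel:
    def train_batch(self, batch):
        return {"loss": 0.0}


# The engine would resolve the method once, then call it per batch.
step = make_step_fn(DemoModel(), "train_batch")
result = step(None)
```

Resolving the method by name keeps the engine generic: a model that lacks, say, `validate_batch` fails at engine creation rather than mid-run.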

docs/data_flow.rst

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ The ``prepare_inputs`` function is responsible for taking in the output of the `
 Model Input and Output Pipeline
 -------------------------------
 
-The data is then either sent to ``pytorch`` ignite for training or to ``onnx`` for inference. In the case of training, the data is converted into a tensor and passed into the model's ``train_step`` function. The model processes the data and returns the output predictions. If the user wishes to perform data augmentations, these can be set up in the model's ``train_step`` function as well.
+The data is then either sent to ``pytorch`` ignite for training or to ``onnx`` for inference. In the case of training, the data is converted into a tensor and passed into the model's ``train_batch`` function. The model processes the data and returns the output predictions. If the user wishes to perform data augmentations, these can be set up in the model's ``train_batch`` function as well.
 
 In the ``onnx`` case, the data remains a numpy array throughout the model evaluation and to the result. Both paths result in the output being a numpy array.
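The two paths described in that doc can be sketched as follows. The doubling is a stand-in for the real model and for an onnxruntime session (both assumptions here), but the array/tensor conversions mirror the description: the training path converts to a tensor and back, the onnx path stays numpy throughout, and both end in a numpy array.

```python
import numpy as np
import torch

data = np.random.rand(2, 3).astype(np.float32)

# Training path: the numpy batch becomes a tensor before the model sees it.
tensor_in = torch.from_numpy(data)
tensor_out = tensor_in * 2.0        # stand-in for the model's computation
train_result = tensor_out.numpy()   # back to numpy for downstream use

# ONNX path: the data stays a numpy array end to end.
onnx_result = data * 2.0            # stand-in for an onnxruntime session run
```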

docs/external_libraries.rst

Lines changed: 5 additions & 5 deletions

@@ -43,7 +43,7 @@ Defining a Model
 ----------------
 
 Models must be written as subclasses of ``torch.nn.Module``, use pytorch for computation, and
-be decorated with ``@hyrax_model``. Models must minimally define ``__init__``, ``forward``, and ``train_step``
+be decorated with ``@hyrax_model``. Models must minimally define ``__init__``, ``prepare_inputs``, ``infer_batch``, and ``train_batch``
 methods.
 
 In order to get the ``@hyrax_model`` decorator you can import it with ``from hyrax.models import hyrax_model``.
@@ -59,21 +59,21 @@ to allow your model class to adjust architecture or check that the provided data
 the first iterable axis of the numpy array.
 
-``forward(self, x)``
+``infer_batch(self, x)``
 ....................
 Hyrax calls this function, which evaluates your model on a single input ``x``. ``x`` is guaranteed to be a numpy array with
 the shape passed to ``__init__``.
 
-``forward()`` should return a numpy array that is the output of your model.
+``infer_batch(self, x)`` should return a numpy array that is the output of your model.
 
-``train_step(self, batch)``
+``train_batch(self, batch)``
 ...........................
 This is called several times every training epoch with a batch of input numpy arrays for your model, and is the
 inner training loop for your model. This is where you compute loss, perform back propagation, etc. depending on
 how your model is trained.
 
-``train_step`` returns a dictionary with a "loss" key whose value is a list of loss values for the individual
+``train_batch`` returns a dictionary with a "loss" key whose value is a list of loss values for the individual
 items in the batch. This loss is logged to MLflow and tensorboard.
 
 Optional Methods
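Putting the required methods together, a minimal model along the lines the doc describes might look like this. It is a sketch under the doc's stated contract — `TinyAutoencoder` and its layer sizes are invented, and a real hyrax model would also need `prepare_inputs` and the `@hyrax_model` decorator:

```python
import torch
import torch.nn as nn


class TinyAutoencoder(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.encoder = nn.Linear(dim, 2)
        self.decoder = nn.Linear(2, dim)
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def train_batch(self, batch):
        # Inner training loop: forward pass, loss, backprop, optimizer step.
        data = batch[0]
        self.optimizer.zero_grad()
        loss = self.criterion(self.forward(data), data)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def infer_batch(self, batch):
        # Inference: no gradient bookkeeping, just the reconstruction.
        data = batch[0]
        with torch.no_grad():
            return self.forward(data)


model = TinyAutoencoder()
batch = (torch.randn(8, 4),)
result = model.train_batch(batch)
out = model.infer_batch(batch)
```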

src/hyrax/models/hsc_autoencoder.py

Lines changed: 74 additions & 2 deletions

@@ -43,15 +43,16 @@ def forward(self, x):
         decoded = self.decoder(encoded)
         return decoded
 
-    def train_step(self, batch):
+    def train_batch(self, batch):
         """
         This function contains the logic for a single training step. i.e. the
         contents of the inner loop of a ML training process.
 
         Parameters
         ----------
         batch : tuple
-            A tuple containing the two values the loss function
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
 
         Returns
         -------
@@ -68,3 +69,74 @@ def train_step(self, batch):
         self.optimizer.step()
 
         return {"loss": loss.item()}
+
+    def validate_batch(self, batch):
+        """
+        This function contains the logic for a single validation step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML validation process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Current loss value : dict
+            Dictionary containing the loss value for the current batch.
+        """
+
+        data = batch[0]
+
+        decoded = self.forward(data)
+        loss = self.criterion(decoded, data)
+
+        return {"loss": loss.item()}
+
+    def test_batch(self, batch):
+        """
+        This function contains the logic for a single testing step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML testing process. In this case, it is identical to `validate_batch`.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Current loss value : dict
+            Dictionary containing the loss value for the current batch.
+        """
+
+        data = batch[0]
+
+        decoded = self.forward(data)
+        loss = self.criterion(decoded, data)
+
+        return {"loss": loss.item()}
+
+    def infer_batch(self, batch):
+        """
+        This function contains the logic for a single inference step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML inference process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Reconstructed outputs : torch.Tensor
+            The reconstructed outputs from the autoencoder.
+        """
+
+        data = batch[0]
+        return self.forward(data)
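The added `validate_batch` and `test_batch` are documented as identical, and the diff duplicates their bodies. A common way to express that relationship without duplication is to delegate one to the other; the sketch below is illustrative (the tiny `TinyAE` model is invented, not the repo's code):

```python
import torch
import torch.nn as nn


class TinyAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)
        self.criterion = nn.MSELoss()

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def validate_batch(self, batch):
        # Evaluate reconstruction loss without touching gradients.
        data = batch[0]
        with torch.no_grad():
            loss = self.criterion(self.forward(data), data)
        return {"loss": loss.item()}

    def test_batch(self, batch):
        # Identical semantics, single implementation to maintain.
        return self.validate_batch(batch)


model = TinyAE()
batch = (torch.randn(5, 4),)
val = model.validate_batch(batch)
tst = model.test_batch(batch)
```

Keeping the bodies separate, as the commit does, leaves room for the two phases to diverge later; delegation is simply the lower-maintenance alternative when they are known to match.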

src/hyrax/models/hsc_dcae.py

Lines changed: 93 additions & 2 deletions

@@ -71,14 +71,15 @@ def forward(self, x):
 
         return x4
 
-    def train_step(self, batch):
+    def train_batch(self, batch):
         """This function contains the logic for a single training step. i.e. the
         contents of the inner loop of a ML training process.
 
         Parameters
         ----------
         batch : tuple
-            A tuple containing the two values the loss function
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
 
         Returns
         -------
@@ -107,3 +108,93 @@ def train_step(self, batch):
         self.optimizer.step()
 
         return {"loss": loss.item()}
+
+    def validate_batch(self, batch):
+        """This function contains the logic for a single validation step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML validation process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Current loss value : dict
+            Dictionary containing the loss value for the current batch.
+        """
+
+        # Dropping labels if present
+        data = batch[0] if isinstance(batch, tuple) else batch
+
+        # Encoder with skip connections
+        x1 = self.activation(self.encoder1(data))
+        x2 = self.activation(self.encoder2(self.pool(x1)))
+        x3 = self.activation(self.encoder3(self.pool(x2)))
+        x4 = self.activation(self.encoder4(self.pool(x3)))
+
+        # Decoder with skip connections
+        x = self.activation(self.decoder4(x4) + x3)
+        x = self.activation(self.decoder3(x) + x2)
+        x = self.activation(self.decoder2(x) + x1)
+        decoded = self.final_activation(self.decoder1(x))
+
+        loss = self.criterion(decoded, data)
+
+        return {"loss": loss.item()}
+
+    def test_batch(self, batch):
+        """This function contains the logic for a single testing step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML testing process. In this case, it is identical to `validate_batch`.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Current loss value : dict
+            Dictionary containing the loss value for the current batch.
+        """
+
+        # Dropping labels if present
+        data = batch[0] if isinstance(batch, tuple) else batch
+
+        # Encoder with skip connections
+        x1 = self.activation(self.encoder1(data))
+        x2 = self.activation(self.encoder2(self.pool(x1)))
+        x3 = self.activation(self.encoder3(self.pool(x2)))
+        x4 = self.activation(self.encoder4(self.pool(x3)))
+
+        # Decoder with skip connections
+        x = self.activation(self.decoder4(x4) + x3)
+        x = self.activation(self.decoder3(x) + x2)
+        x = self.activation(self.decoder2(x) + x1)
+        decoded = self.final_activation(self.decoder1(x))
+
+        loss = self.criterion(decoded, data)
+
+        return {"loss": loss.item()}
+
+    def infer_batch(self, batch):
+        """This function contains the logic for a single inference step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML inference process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Reconstructed outputs : torch.Tensor
+            The reconstructed outputs from the autoencoder.
+        """
+        return self.forward(batch)
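The DCAE methods above inline the skip-connection forward pass: each decoder stage upsamples the deeper feature map and adds the matching encoder activation. The pattern can be shown in miniature; `SkipAE` below is an invented toy with one pooling level, not the repo's model:

```python
import torch
import torch.nn as nn


class SkipAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Conv2d(1, 4, 3, padding=1)   # (N, 4, 8, 8)
        self.enc2 = nn.Conv2d(4, 8, 3, padding=1)   # (N, 8, 4, 4) after pool
        self.pool = nn.MaxPool2d(2)
        self.dec2 = nn.ConvTranspose2d(8, 4, 2, stride=2)  # back to (N, 4, 8, 8)
        self.dec1 = nn.Conv2d(4, 1, 3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        x1 = self.act(self.enc1(x))
        x2 = self.act(self.enc2(self.pool(x1)))
        # Skip connection: upsampled deep features plus the encoder activation
        # at the same resolution.
        y = self.act(self.dec2(x2) + x1)
        return torch.sigmoid(self.dec1(y))


out = SkipAE()(torch.randn(2, 1, 8, 8))
```

The element-wise add requires the transposed convolution to restore exactly the encoder stage's channel count and spatial size, which is why the encoder and decoder layer shapes are paired.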

src/hyrax/models/hyrax_autoencoder.py

Lines changed: 67 additions & 3 deletions

@@ -20,7 +20,7 @@ class HyraxAutoencoder(nn.Module):
     This example model is taken from this
     `autoencoder tutorial <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html>`_
 
-    The train function has been converted into train_step for use with pytorch-ignite.
+    The train function has been converted into train_batch for use with pytorch-ignite.
     """
 
     def __init__(self, config, data_sample=None):
@@ -109,14 +109,15 @@ def _eval_decoder(self, x):
     def forward(self, batch):
         return self._eval_encoder(batch)
 
-    def train_step(self, batch):
+    def train_batch(self, batch):
         """This function contains the logic for a single training step. i.e. the
         contents of the inner loop of a ML training process.
 
         Parameters
         ----------
         batch : tuple
-            A tuple containing the inputs and labels for the current batch.
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
 
         Returns
        -------
@@ -136,6 +137,69 @@ def train_step(self, batch):
 
         return {"loss": loss.item()}
 
+    def validate_batch(self, batch):
+        """This function contains the logic for a single validation step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML validation process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Current loss value : dict
+            Dictionary containing the loss value for the current batch.
+        """
+        z = self._eval_encoder(batch)
+        x_hat = self._eval_decoder(z)
+        loss = F.mse_loss(batch, x_hat, reduction="none")
+        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
+
+        return {"loss": loss.item()}
+
+    def test_batch(self, batch):
+        """This function contains the logic for a single testing step that will
+        process a single batch of data. i.e. the contents of the inner loop of a
+        ML testing process. In this case, it is identical to `validate_batch`.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Current loss value : dict
+            Dictionary containing the loss value for the current batch.
+        """
+        z = self._eval_encoder(batch)
+        x_hat = self._eval_decoder(z)
+        loss = F.mse_loss(batch, x_hat, reduction="none")
+        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
+
+        return {"loss": loss.item()}
+
+    def infer_batch(self, batch):
+        """This function contains the logic for a single inference step. i.e. the
+        contents of the inner loop of a ML inference process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the input data for the current batch, possibly
+            with labels that are ignored.
+
+        Returns
+        -------
+        Reconstructed inputs : torch.Tensor
+            The reconstructed inputs from the autoencoder.
+        """
+        return self.forward(batch)
+
     @staticmethod
     def prepare_inputs(data_dict) -> tuple:
         """This function converts structured data to the input tensor we need to run
