Commit c7c0df0

Add model_transform in create supervised trainer (#2848)
* Add model_transform in create supervised trainer
* autopep8 fix
* Made changes in the model_transform
* autopep8 fix
* Add test for Supervised trainer output transform
* autopep8 fix
* changed code formatting
* Add necessary changes to tests for model transform
* autopep8 fix
* Some code formatting changes
* autopep8 fix
* Made code formatting changes
* autopep8 fix
* Code formatting changes
* Added test for model_output_transform
* autopep8 fix
* Changed something in the test
* Updated tests
* Update tests/ignite/engine/test_create_supervised.py
* Made some changes in the tests
* autopep8 fix
* Some change
* Revert tests and add docstrings
* Changed the test version
* Some necessary changes

---------

Co-authored-by: guptaaryan16 <[email protected]>
Co-authored-by: vfdev-5 <[email protected]>
1 parent 34be221 commit c7c0df0
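For context, the new argument lets the supervised trainer work with models whose forward pass does not return a plain prediction tensor. The following is a minimal usage sketch, assuming ignite 0.4.11 or later (where `model_transform` is available) and a hypothetical toy model DictOutputNet whose forward returns a dict:

import torch
import torch.nn as nn
from ignite.engine import create_supervised_trainer

class DictOutputNet(nn.Module):  # hypothetical toy model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        # returns a dict, which loss_fn cannot consume directly
        return {"logits": self.fc(x)}

model = DictOutputNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    model_transform=lambda output: output["logits"],  # dict -> tensor before loss_fn
)

data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(5)]
trainer.run(data, max_epochs=1)

Before this commit, such a model required a fully custom update function; the transform hook keeps the standard trainer usable.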

File tree

2 files changed: +99, -49 lines


ignite/engine/__init__.py

Lines changed: 38 additions & 4 deletions
@@ -48,6 +48,7 @@ def supervised_training_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -64,6 +65,8 @@ def supervised_training_step(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the form as required
+            by the loss function
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
@@ -86,6 +89,8 @@ def supervised_training_step(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output
     """
 
     if gradient_accumulation_steps <= 0:
@@ -99,7 +104,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
+        output = model(x)
+        y_pred = model_transform(output)
         loss = loss_fn(y_pred, y)
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
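The hunk above is the heart of the change: the raw forward output is routed through model_transform before the loss is computed. As a plain-PyTorch sketch of what one generated update step now does (gradient accumulation, device placement and output_transform omitted; training_step is a hypothetical helper, not part of ignite):

def training_step(model, optimizer, loss_fn, batch, model_transform=lambda o: o):
    model.train()
    x, y = batch
    output = model(x)                 # raw forward output, e.g. a dict or tuple
    y_pred = model_transform(output)  # e.g. lambda out: out["logits"]
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()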
@@ -118,6 +124,7 @@ def supervised_training_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     scaler: Optional["torch.cuda.amp.GradScaler"] = None,
     gradient_accumulation_steps: int = 1,
@@ -135,6 +142,8 @@ def supervised_training_step_amp(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the form as required
+            by the loss function
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         scaler: GradScaler instance for gradient scaling. (default: None)
@@ -160,6 +169,8 @@ def supervised_training_step_amp(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output
     """
 
     try:
@@ -179,7 +190,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         with autocast(enabled=True):
-            y_pred = model(x)
+            output = model(x)
+            y_pred = model_transform(output)
             loss = loss_fn(y_pred, y)
             if gradient_accumulation_steps > 1:
                 loss = loss / gradient_accumulation_steps
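In the AMP variant the transform is applied inside the autocast block, so it runs under mixed precision as well. A usage sketch, reusing model, optimizer and loss_fn from the earlier example and assuming a CUDA device; amp_mode and scaler are the create_supervised_trainer arguments that select this code path (scaler=True asks ignite to create a GradScaler internally):

trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    device="cuda",
    amp_mode="amp",    # selects supervised_training_step_amp
    scaler=True,       # gradient scaling with torch.cuda.amp.GradScaler
    model_transform=lambda output: output["logits"],
)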
@@ -204,6 +216,7 @@ def supervised_training_step_apex(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -220,6 +233,8 @@ def supervised_training_step_apex(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the form as required
+            by the loss function
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
@@ -243,6 +258,8 @@ def supervised_training_step_apex(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output
     """
 
     try:
@@ -261,7 +278,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
+        output = model(x)
+        y_pred = model_transform(output)
         loss = loss_fn(y_pred, y)
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
@@ -281,6 +299,7 @@ def supervised_training_step_tpu(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -297,6 +316,8 @@ def supervised_training_step_tpu(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the form as required
+            by the loss function
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
@@ -320,6 +341,8 @@ def supervised_training_step_tpu(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation argument for all supervised training methods.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output
     """
     try:
         import torch_xla.core.xla_model as xm
@@ -337,7 +360,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
+        output = model(x)
+        y_pred = model_transform(output)
         loss = loss_fn(y_pred, y)
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
@@ -384,6 +408,7 @@ def create_supervised_trainer(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
@@ -403,6 +428,8 @@ def create_supervised_trainer(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the form as required
+            by the loss function
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         deterministic: if True, returns deterministic engine of type
@@ -496,6 +523,8 @@ def output_transform_fn(x, y, y_pred, loss):
 
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation argument for all supervised training methods.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output
     """
 
     device_type = device.type if isinstance(device, torch.device) else device
@@ -510,6 +539,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             _scaler,
             gradient_accumulation_steps,
@@ -522,6 +552,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             gradient_accumulation_steps,
         )
@@ -533,6 +564,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             gradient_accumulation_steps,
         )
@@ -544,6 +576,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             gradient_accumulation_steps,
         )
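As the hunks above show, create_supervised_trainer only forwards model_transform to whichever step factory matches the requested device and amp mode. The step factories can also be used directly when building a custom Engine; a sketch, again assuming ignite 0.4.11+ and reusing model, optimizer and loss_fn from the first example:

from ignite.engine import Engine, supervised_training_step

update_fn = supervised_training_step(
    model,
    optimizer,
    loss_fn,
    model_transform=lambda output: output["logits"],
)
custom_trainer = Engine(update_fn)  # roughly what create_supervised_trainer builds on the default path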
@@ -662,6 +695,7 @@ def create_supervised_evaluator(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
     amp_mode: Optional[str] = None,
 ) -> Engine:
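The same argument is added to create_supervised_evaluator, so the transform used during training can be reused at evaluation time before metrics see the predictions. A sketch, assuming ignite 0.4.11+ and reusing model and data from the first example:

from ignite.metrics import Accuracy
from ignite.engine import create_supervised_evaluator

evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy()},
    model_transform=lambda output: output["logits"],  # metrics then receive (y_pred, y)
)
evaluator.run(data)
print(evaluator.state.metrics["accuracy"])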
