@@ -48,6 +48,7 @@ def supervised_training_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -64,6 +65,8 @@ def supervised_training_step(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the form
+            required by the loss function.
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
@@ -86,6 +89,8 @@ def supervised_training_step(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output.
     """
 
     if gradient_accumulation_steps <= 0:
@@ -99,7 +104,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
             optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
+        output = model(x)
+        y_pred = model_transform(output)
         loss = loss_fn(y_pred, y)
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
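The change to `update` above is the core of the patch: the raw model output is kept in `output` and passed through `model_transform` before the loss is computed, so models that return something other than a single tensor can be used without a wrapper. A minimal sketch of the idea outside the engine (the `DictOutputModel` below is a hypothetical example, not part of this diff):

import torch
import torch.nn as nn

class DictOutputModel(nn.Module):
    """Hypothetical model whose forward returns a dict instead of a bare tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> dict:
        return {"logits": self.fc(x), "aux": x.mean()}

model = DictOutputModel()
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# With the default identity model_transform, loss_fn would receive the dict and fail.
model_transform = lambda output: output["logits"]

output = model(x)                 # mirrors the new `output = model(x)` line
y_pred = model_transform(output)  # mirrors the new `y_pred = model_transform(output)` line
loss = loss_fn(y_pred, y)
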
@@ -118,6 +124,7 @@ def supervised_training_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     scaler: Optional["torch.cuda.amp.GradScaler"] = None,
     gradient_accumulation_steps: int = 1,
@@ -135,6 +142,8 @@ def supervised_training_step_amp(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the form
+            required by the loss function.
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         scaler: GradScaler instance for gradient scaling. (default: None)
@@ -160,6 +169,8 @@ def supervised_training_step_amp(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output.
     """
 
     try:
@@ -179,7 +190,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         with autocast(enabled=True):
-            y_pred = model(x)
+            output = model(x)
+            y_pred = model_transform(output)
             loss = loss_fn(y_pred, y)
             if gradient_accumulation_steps > 1:
                 loss = loss / gradient_accumulation_steps
@@ -204,6 +216,7 @@ def supervised_training_step_apex(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -220,6 +233,8 @@ def supervised_training_step_apex(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the form
+            required by the loss function.
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
@@ -243,6 +258,8 @@ def supervised_training_step_apex(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output.
     """
 
     try:
@@ -261,7 +278,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
             optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
+        output = model(x)
+        y_pred = model_transform(output)
         loss = loss_fn(y_pred, y)
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
@@ -281,6 +299,7 @@ def supervised_training_step_tpu(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -297,6 +316,8 @@ def supervised_training_step_tpu(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the form
+            required by the loss function.
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
@@ -320,6 +341,8 @@ def supervised_training_step_tpu(
     .. versionadded:: 0.4.5
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation argument for all supervised training methods.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output.
     """
     try:
         import torch_xla.core.xla_model as xm
@@ -337,7 +360,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
             optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
+        output = model(x)
+        y_pred = model_transform(output)
         loss = loss_fn(y_pred, y)
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
@@ -384,6 +408,7 @@ def create_supervised_trainer(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
@@ -403,6 +428,8 @@ def create_supervised_trainer(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the form
+            required by the loss function.
         output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
         deterministic: if True, returns deterministic engine of type
@@ -496,6 +523,8 @@ def output_transform_fn(x, y, y_pred, loss):
 
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation argument for all supervised training methods.
+    .. versionchanged:: 0.4.11
+        Added `model_transform` to transform model's output.
     """
 
     device_type = device.type if isinstance(device, torch.device) else device
@@ -510,6 +539,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             _scaler,
             gradient_accumulation_steps,
@@ -522,6 +552,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             gradient_accumulation_steps,
         )
@@ -533,6 +564,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             gradient_accumulation_steps,
         )
@@ -544,6 +576,7 @@ def output_transform_fn(x, y, y_pred, loss):
             device,
             non_blocking,
             prepare_batch,
+            model_transform,
             output_transform,
             gradient_accumulation_steps,
         )
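Since `create_supervised_trainer` now forwards `model_transform` to every training-step implementation above, callers can pass it directly when building a trainer. A rough usage sketch, assuming an ignite release that includes this change (0.4.11 or later); the `TupleOutputModel`, data, and hyperparameters are invented here purely for illustration:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer

class TupleOutputModel(nn.Module):
    """Toy model returning (logits, hidden) to show why model_transform is needed."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor):
        hidden = torch.relu(x)
        return self.fc(hidden), hidden

model = TupleOutputModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
train_loader = DataLoader(dataset, batch_size=8)

trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    model_transform=lambda output: output[0],  # keep only the logits for the loss
)
trainer.run(train_loader, max_epochs=1)
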
@@ -662,6 +695,7 @@ def create_supervised_evaluator(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
     amp_mode: Optional[str] = None,
 ) -> Engine: