1- from typing import List
1+ from typing import Dict , Optional , Tuple , List
22
33import numpy as np
44import torch # type: ignore[import-untyped]
@@ -99,12 +99,13 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Prediction:
9999 assert len (self .scales ) > 0
100100 assert len (self .models ) > 0
101101 assert len (self .steps ) > 0
102+ prediction = None
102103 x_scaled = downscale (x , self .scales [0 ])
103104 for i , (model , scale , scale_steps ) in enumerate (
104105 zip (self .models , self .scales , self .steps )
105106 ):
106107 steps = scale_steps + np .random .randint (
107- - int (scale_steps * 0.2 ), int (scale_steps * 0.2 )
108+ - int (scale_steps * 0.2 ), int (scale_steps * 0.2 ) + 1
108109 )
109110 if steps <= 0 :
110111 steps = 1
@@ -117,12 +118,13 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Prediction:
117118 self .scales [i + 1 ],
118119 )
119120 # TODO prediction has incorrect number of steps
120- return prediction
121+ return unwrap ( prediction )
121122
122123 def record_steps (self , x : torch .Tensor ):
123124 # TODO let "Prediction" class record steps
124125 step_outputs = []
125126 x_scaled = downscale (x , self .scales [0 ])
127+ prediction = None
126128 for i , (model , scale , scale_steps ) in enumerate (
127129 zip (self .models , self .scales , self .steps )
128130 ):
@@ -132,15 +134,15 @@ def record_steps(self, x: torch.Tensor):
132134 step_outputs .append (upscale (prediction .output_image , scale ))
133135 x_in = prediction .output_image
134136 if i < len (self .scales ) - 1 :
135- x_scaled = upscale (prediction .output_image , scale / self .scales [i + 1 ])
137+ x_scaled = upscale (unwrap ( prediction ) .output_image , scale / self .scales [i + 1 ])
136138 # replace input with downscaled variant of original image
137139 x_scaled [:, : model .num_image_channels , :, :] = downscale (
138140 x [:, : model .num_image_channels , :, :],
139141 self .scales [i + 1 ],
140142 )
141143 return step_outputs
142144
143- def validate (self , image : torch .Tensor , label : torch .Tensor , steps : int = 1 ):
145+ def validate (self , image : torch .Tensor , label : torch .Tensor , steps : int = 1 ) -> Optional [ Tuple [ Dict [ str , float ], Prediction ]] :
144146 """
145147 Validation method.
146148
@@ -175,4 +177,4 @@ def validate(self, image: torch.Tensor, label: torch.Tensor, steps: int = 1):
175177 image [:, : model .num_image_channels , :, :],
176178 self .scales [i + 1 ],
177179 )
178- return metrics , prediction
180+ return unwrap ( metrics ), unwrap ( prediction )
0 commit comments