@@ -98,7 +98,7 @@ def __init__(
         self.gaussian_conditional = gaussian_conditional or GaussianConditional(None)
         self.entropy_parameters = entropy_parameters or nn.Identity()
         self.context_prediction = context_prediction or MaskedConv2d()
-        self.kernel_size = _reduce_seq(self.context_prediction.kernel_size)
+        self.kernel_size = _to_single(self.context_prediction.kernel_size)
         self.padding = (self.kernel_size - 1) // 2

     def forward(self, y: Tensor, params: Tensor) -> Dict[str, Any]:
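Note (not part of the diff): a tiny sketch of the renamed helper in use, based on the `_to_single` body and the padding formula visible elsewhere in this diff; the `(5, 5)` kernel size is a placeholder.

```python
# Illustrative only: `_to_single` (renamed from `_reduce_seq`) is assumed to be
# imported from the module being modified. It asserts all entries are equal and
# returns the first, so a square kernel tuple collapses to a single int.
kernel_size = _to_single((5, 5))   # -> 5
padding = (kernel_size - 1) // 2   # -> 2, the "same" padding computed in __init__
```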
@@ -113,8 +113,11 @@ def forward(self, y: Tensor, params: Tensor) -> Dict[str, Any]:

     def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
         n, _, y_height, y_width = y.shape
-        ds = [
-            self._compress_single(
+        ds = []
+        for i in range(n):
+            encoder = BufferedRansEncoder()
+            y_hat = raster_scan_compress_single_stream(
+                encoder=encoder,
                 y=y[i : i + 1, :, :, :],
                 params=ctx_params[i : i + 1, :, :, :],
                 gaussian_conditional=self.gaussian_conditional,
@@ -126,16 +129,10 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
                 kernel_size=self.kernel_size,
                 merge=self.merge,
             )
-            for i in range(n)
-        ]
+            y_strings = encoder.flush()
+            ds.append({"strings": [y_strings], "y_hat": y_hat.squeeze(0)})
         return {**default_collate(ds), "shape": y.shape[2:4]}

-    def _compress_single(self, **kwargs):
-        encoder = BufferedRansEncoder()
-        y_hat = raster_scan_compress_single_stream(encoder=encoder, **kwargs)
-        y_strings = encoder.flush()
-        return {"strings": [y_strings], "y_hat": y_hat.squeeze(0)}
-
     def decompress(
         self,
         strings: List[List[bytes]],
@@ -145,9 +142,12 @@ def decompress(
     ) -> Dict[str, Any]:
         (y_strings,) = strings
         y_height, y_width = shape
-        ds = [
-            self._decompress_single(
-                y_string=y_strings[i],
+        ds = []
+        for i in range(len(y_strings)):
+            decoder = RansDecoder()
+            decoder.set_stream(y_strings[i])
+            y_hat = raster_scan_decompress_single_stream(
+                decoder=decoder,
                 params=ctx_params[i : i + 1, :, :, :],
                 gaussian_conditional=self.gaussian_conditional,
                 entropy_parameters=self.entropy_parameters,
@@ -159,16 +159,9 @@ def decompress(
                 device=ctx_params.device,
                 merge=self.merge,
             )
-            for i in range(len(y_strings))
-        ]
+            ds.append({"y_hat": y_hat.squeeze(0)})
         return default_collate(ds)

-    def _decompress_single(self, y_string, **kwargs):
-        decoder = RansDecoder()
-        decoder.set_stream(y_string)
-        y_hat = raster_scan_decompress_single_stream(decoder=decoder, **kwargs)
-        return {"y_hat": y_hat.squeeze(0)}
-
     @staticmethod
     def merge(*args):
         return torch.cat(args, dim=1)
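For orientation, a hedged round-trip sketch of how the refactored `compress`/`decompress` pair is typically exercised. `codec`, `y`, and `ctx_params` are placeholder names, and the `shape`/`ctx_params` keyword names for `decompress` are inferred from the method body, since the full signature is not visible in this diff.

```python
# Illustrative sketch, not part of the change. `codec` is assumed to be an
# instance of the latent codec modified above; `y` is a latent tensor of shape
# (N, C, H, W) and `ctx_params` its matching context parameters.
enc = codec.compress(y, ctx_params)
# `enc` is expected to carry "strings", "y_hat", and "shape" (see compress above).
dec = codec.decompress(strings=enc["strings"], shape=enc["shape"], ctx_params=ctx_params)
# `dec["y_hat"]` holds the per-sample decoded latents, collated back together.
```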
@@ -312,12 +305,16 @@ def _pad_2d(x: Tensor, padding: int) -> Tensor:
     return F.pad(x, (padding, padding, padding, padding))


-def _reduce_seq(xs):
+def _to_single(xs):
     assert all(x == xs[0] for x in xs)
     return xs[0]


 def default_collate(batch: List[Dict[K, V]]) -> Dict[K, List[V]]:
+    """Combines a list of dictionaries into a single dictionary.
+
+    Workaround to ``torch.utils.data.default_collate`` bug in PyTorch 2.0.0.
+    """
     if not isinstance(batch, list) or any(not isinstance(d, dict) for d in batch):
         raise NotImplementedError

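As a rough illustration of the `default_collate` helper documented above: per its annotated return type `Dict[K, List[V]]`, it is assumed to group per-sample dictionary values by key, mirroring `torch.utils.data.default_collate` for a list of dicts; the values below are placeholders.

```python
# Hypothetical usage of the default_collate defined above; byte strings stand in
# for real entropy-coded streams.
ds = [
    {"strings": [b"stream-0"], "y_hat": "decoded latent for sample 0"},
    {"strings": [b"stream-1"], "y_hat": "decoded latent for sample 1"},
]
out = default_collate(ds)
# Each key of `out` combines the corresponding per-sample values of the batch,
# e.g. out["strings"] gathers both byte strings.
```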