@@ -1091,8 +1091,6 @@ def forward(
10911091        sample_posterior : bool  =  False ,
10921092        return_dict : bool  =  True ,
10931093        generator : Optional [torch .Generator ] =  None ,
1094-         encoder_local_batch_size : int  =  2 ,
1095-         decoder_local_batch_size : int  =  2 ,
10961094    ) ->  Union [DecoderOutput , torch .Tensor ]:
10971095        r""" 
10981096        Args: 
@@ -1103,18 +1101,14 @@ def forward(
11031101                Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
11041102            generator (`torch.Generator`, *optional*): 
11051103                PyTorch random number generator. 
1106-             encoder_local_batch_size (`int`, *optional*, defaults to 2): 
1107-                 Local batch size for the encoder's batch inference. 
1108-             decoder_local_batch_size (`int`, *optional*, defaults to 2): 
1109-                 Local batch size for the decoder's batch inference. 
11101104        """ 
11111105        x  =  sample 
1112-         posterior  =  self .encode (x ,  local_batch_size = encoder_local_batch_size ).latent_dist 
1106+         posterior  =  self .encode (x ).latent_dist 
11131107        if  sample_posterior :
11141108            z  =  posterior .sample (generator = generator )
11151109        else :
11161110            z  =  posterior .mode ()
1117-         dec  =  self .decode (z ,  local_batch_size = decoder_local_batch_size ).sample 
1111+         dec  =  self .decode (z ).sample 
11181112
11191113        if  not  return_dict :
11201114            return  (dec ,)
0 commit comments