@@ -56,6 +56,7 @@ def _inference_transform(model, inputs):
     #TODO: I am not sure what is the best way with dealing with the types and
     # device when using batched inference. This works for now.
     inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device)
+
     if isinstance(model, cebra.models.ConvolutionalModelMixin):
         # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
         inputs = inputs.transpose(1, 0).unsqueeze(0)
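Note on the hunk above: fully convolutional models consume the whole series as a single batch element, hence the (T, C) -> (1, C, T) reshape before inference. A minimal sketch of that step, with illustrative shapes:

import torch

T, C = 100, 8                      # illustrative: 100 time steps, 8 channels
inputs = torch.randn(T, C)
# Fully convolutional evaluation, switch (T, C) -> (1, C, T)
conv_inputs = inputs.transpose(1, 0).unsqueeze(0)
assert conv_inputs.shape == (1, C, T)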
@@ -110,8 +111,6 @@ def _check_indices(indices, inputs):
 
 def _check_batch_size_length(indices_batch, offset):
     batch_size_lenght = indices_batch[1] - indices_batch[0]
-    print("batch_size ll", add_padding, indices, batch_size_lenght,
-          len(offset))
     if batch_size_lenght <= len(offset):
         raise ValueError(
             f"The batch has length {batch_size_lenght} which "
@@ -489,10 +488,8 @@ def _transform(self, model, inputs, offset,
                    pad_before_transform) -> torch.Tensor:
 
         if pad_before_transform:
-            inputs = np.pad(inputs, ((offset.left, offset.right - 1), (0, 0)),
-                            mode="edge")
-            inputs = torch.from_numpy(inputs)
-
+            inputs = F.pad(inputs.T, (offset.left, offset.right - 1),
+                           'replicate').T
         output = _inference_transform(model, inputs)
         return output
 
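The replacement above swaps numpy edge padding plus a torch round-trip for torch-native replicate padding, which keeps the tensor on its device. The two paths should be numerically identical; a small sketch checking that equivalence (assumes a PyTorch version where 'replicate' padding accepts 2D inputs):

import numpy as np
import torch
import torch.nn.functional as F

T, C = 10, 3
left, right = 2, 3
x = torch.randn(T, C)

# Old path: numpy edge padding along the time axis, then back to torch.
ref = torch.from_numpy(
    np.pad(x.numpy(), ((left, right - 1), (0, 0)), mode="edge"))

# New path: replicate-pad the last axis of the transposed (C, T) tensor.
out = F.pad(x.T, (left, right - 1), 'replicate').T

assert torch.equal(ref, out)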