@@ -97,7 +97,7 @@ def convert_pt_to_tf(y: List[Dict[str, np.ndarray]], height: int, width: int) ->
9797def cast_inputs_to_pt (
9898 x : Union [np .ndarray , "torch.Tensor" ],
9999 y : Optional [List [Dict [str , Union [np .ndarray , "torch.Tensor" ]]]] = None ,
100- ) -> Tuple ["torch.Tensor" , List [Dict [str , "torch.Tensor" ]]]:
100+ ) -> Tuple ["torch.Tensor" , Optional [ List [Dict [str , "torch.Tensor" ] ]]]:
101101 """
102102 Cast object detection inputs `(x, y)` to PyTorch tensors.
103103
@@ -117,25 +117,43 @@ def cast_inputs_to_pt(
117117 else :
118118 x_tensor = x
119119
120+ y_tensor : Optional [List [Dict [str , torch .Tensor ]]] = None
121+
120122 # Convert labels into tensor
121- if y is not None and isinstance (y , list ) and isinstance ( y [ 0 ][ "boxes" ], np . ndarray ):
123+ if isinstance (y , list ):
122124 y_tensor = []
123125 for y_i in y :
124- y_t = {
125- "boxes" : torch .from_numpy (y_i ["boxes" ]).to (dtype = torch .float32 ),
126- "labels" : torch .from_numpy (y_i ["labels" ]).to (dtype = torch .int64 ),
127- }
126+ y_t = {}
127+
128+ if isinstance (y_i ["boxes" ], np .ndarray ):
129+ y_t ["boxes" ] = torch .from_numpy (y_i ["boxes" ]).to (dtype = torch .float32 )
130+ else :
131+ y_t ["boxes" ] = y_i ["boxes" ]
132+
133+ if isinstance (y_i ["labels" ], np .ndarray ):
134+ y_t ["labels" ] = torch .from_numpy (y_i ["labels" ]).to (dtype = torch .int64 )
135+ else :
136+ y_t ["labels" ] = y_i ["labels" ]
137+
128138 if "masks" in y_i :
129- y_t ["masks" ] = torch .from_numpy (y_i ["masks" ]).to (dtype = torch .uint8 )
139+ if isinstance (y_i ["masks" ], np .ndarray ):
140+ y_t ["masks" ] = torch .from_numpy (y_i ["masks" ]).to (dtype = torch .uint8 )
141+ else :
142+ y_t ["masks" ] = y_i ["masks" ]
143+
130144 y_tensor .append (y_t )
131- elif y is not None and isinstance (y , dict ):
145+ elif isinstance (y , dict ):
132146 y_tensor = []
133- for i in range (y ["boxes" ]. shape [ 0 ] ):
147+ for i in range (len ( y ["boxes" ]) ):
134148 y_t = {}
149+
135150 y_t ["boxes" ] = y ["boxes" ][i ]
136151 y_t ["labels" ] = y ["labels" ][i ]
152+ if "masks" in y :
153+ y_t ["masks" ] = y ["masks" ][i ]
154+
137155 y_tensor .append (y_t )
138156 else :
139- y_tensor = y # type: ignore
157+ y_tensor = y
140158
141159 return x_tensor , y_tensor
0 commit comments