Skip to content

Commit ddc09f7

Browse files
committed
fix type checking for object detection utils
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent ada0830 commit ddc09f7

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

art/estimators/object_detection/utils.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def convert_pt_to_tf(y: List[Dict[str, np.ndarray]], height: int, width: int) ->
9797
def cast_inputs_to_pt(
9898
x: Union[np.ndarray, "torch.Tensor"],
9999
y: Optional[List[Dict[str, Union[np.ndarray, "torch.Tensor"]]]] = None,
100-
) -> Tuple["torch.Tensor", List[Dict[str, "torch.Tensor"]]]:
100+
) -> Tuple["torch.Tensor", Optional[List[Dict[str, "torch.Tensor"]]]]:
101101
"""
102102
Cast object detection inputs `(x, y)` to PyTorch tensors.
103103
@@ -117,25 +117,43 @@ def cast_inputs_to_pt(
117117
else:
118118
x_tensor = x
119119

120+
y_tensor: Optional[List[Dict[str, torch.Tensor]]] = None
121+
120122
# Convert labels into tensor
121-
if y is not None and isinstance(y, list) and isinstance(y[0]["boxes"], np.ndarray):
123+
if isinstance(y, list):
122124
y_tensor = []
123125
for y_i in y:
124-
y_t = {
125-
"boxes": torch.from_numpy(y_i["boxes"]).to(dtype=torch.float32),
126-
"labels": torch.from_numpy(y_i["labels"]).to(dtype=torch.int64),
127-
}
126+
y_t = {}
127+
128+
if isinstance(y_i["boxes"], np.ndarray):
129+
y_t["boxes"] = torch.from_numpy(y_i["boxes"]).to(dtype=torch.float32)
130+
else:
131+
y_t["boxes"] = y_i["boxes"]
132+
133+
if isinstance(y_i["labels"], np.ndarray):
134+
y_t["labels"] = torch.from_numpy(y_i["labels"]).to(dtype=torch.int64)
135+
else:
136+
y_t["labels"] = y_i["labels"]
137+
128138
if "masks" in y_i:
129-
y_t["masks"] = torch.from_numpy(y_i["masks"]).to(dtype=torch.uint8)
139+
if isinstance(y_i["masks"], np.ndarray):
140+
y_t["masks"] = torch.from_numpy(y_i["masks"]).to(dtype=torch.uint8)
141+
else:
142+
y_t["masks"] = y_i["masks"]
143+
130144
y_tensor.append(y_t)
131-
elif y is not None and isinstance(y, dict):
145+
elif isinstance(y, dict):
132146
y_tensor = []
133-
for i in range(y["boxes"].shape[0]):
147+
for i in range(len(y["boxes"])):
134148
y_t = {}
149+
135150
y_t["boxes"] = y["boxes"][i]
136151
y_t["labels"] = y["labels"][i]
152+
if "masks" in y:
153+
y_t["masks"] = y["masks"][i]
154+
137155
y_tensor.append(y_t)
138156
else:
139-
y_tensor = y # type: ignore
157+
y_tensor = y
140158

141159
return x_tensor, y_tensor

0 commit comments

Comments
 (0)