Skip to content

Commit 0e76e35

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
scripting support for Instances.cat and Instances.__init__(with arguments)
Summary: needed for single stage detector refactoring Reviewed By: newstzpz Differential Revision: D30977541 fbshipit-source-id: b596a9996d98caf538a22c8e2e191d4039d013b6
1 parent 592128b commit 0e76e35

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

detectron2/export/torchscript_patch.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,19 @@ def indent(level, s):
113113
cls_name = "ScriptedInstances{}".format(_counter)
114114

115115
field_names = tuple(x.name for x in fields)
116+
extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
116117
lines.append(
117118
f"""
118119
class {cls_name}:
119-
def __init__(self, image_size: Tuple[int, int]):
120+
def __init__(self, image_size: Tuple[int, int], {extra_args}):
120121
self.image_size = image_size
121122
self._field_names = {field_names}
122123
"""
123124
)
124125

125126
for f in fields:
126127
lines.append(
127-
indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], None)")
128+
indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
128129
)
129130

130131
for f in fields:
@@ -135,7 +136,7 @@ def {f.name}(self) -> {f.annotation}:
135136
# has to use a local for type refinement
136137
# https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
137138
t = self._{f.name}
138-
assert t is not None
139+
assert t is not None, "{f.name} is None and cannot be accessed!"
139140
return t
140141
141142
@{f.name}.setter
@@ -184,10 +185,11 @@ def has(self, name: str) -> bool:
184185
)
185186

186187
# support method `to`
188+
none_args = ", None" * len(fields)
187189
lines.append(
188190
f"""
189191
def to(self, device: torch.device) -> "{cls_name}":
190-
ret = {cls_name}(self.image_size)
192+
ret = {cls_name}(self.image_size{none_args})
191193
"""
192194
)
193195
for f in fields:
@@ -210,10 +212,11 @@ def to(self, device: torch.device) -> "{cls_name}":
210212
)
211213

212214
# support method `getitem`
215+
none_args = ", None" * len(fields)
213216
lines.append(
214217
f"""
215218
def __getitem__(self, item) -> "{cls_name}":
216-
ret = {cls_name}(self.image_size)
219+
ret = {cls_name}(self.image_size{none_args})
217220
"""
218221
)
219222
for f in fields:
@@ -230,6 +233,32 @@ def __getitem__(self, item) -> "{cls_name}":
230233
"""
231234
)
232235

236+
# support method `cat`
237+
# this version does not contain checks that all instances have same size and fields
238+
none_args = ", None" * len(fields)
239+
lines.append(
240+
f"""
241+
def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
242+
ret = {cls_name}(self.image_size{none_args})
243+
"""
244+
)
245+
for f in fields:
246+
lines.append(
247+
f"""
248+
t = self._{f.name}
249+
if t is not None:
250+
values: List[{f.annotation}] = [x.{f.name} for x in instances]
251+
if torch.jit.isinstance(t, torch.Tensor):
252+
ret._{f.name} = torch.cat(values, dim=0)
253+
else:
254+
ret._{f.name} = t.cat(values)
255+
"""
256+
)
257+
lines.append(
258+
"""
259+
return ret"""
260+
)
261+
233262
# support method `get_fields()`
234263
lines.append(
235264
"""

tests/structures/test_instances.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,41 @@ def test_from_to_instances(self):
181181
self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new1.proposal_boxes.tensor))
182182
self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new2.proposal_boxes.tensor))
183183

184+
def test_script_init_args(self):
185+
def f(x: Tensor):
186+
image_shape = (15, 15)
187+
# __init__ can take arguments
188+
inst = Instances(image_shape, a=x, proposal_boxes=Boxes(x))
189+
inst2 = Instances(image_shape, a=x)
190+
return inst.a, inst2.a
191+
192+
fields = {"proposal_boxes": Boxes, "a": Tensor}
193+
with patch_instances(fields):
194+
script_f = torch.jit.script(f)
195+
x = torch.randn(3, 4)
196+
outputs = script_f(x)
197+
self.assertTrue(torch.equal(outputs[0], x))
198+
self.assertTrue(torch.equal(outputs[1], x))
199+
200+
def test_script_cat(self):
201+
def f(x: Tensor):
202+
image_shape = (15, 15)
203+
# __init__ can take arguments
204+
inst = Instances(image_shape, a=x)
205+
inst2 = Instances(image_shape, a=x)
206+
207+
inst3 = Instances(image_shape, proposal_boxes=Boxes(x))
208+
return inst.cat([inst, inst2]), inst3.cat([inst3, inst3])
209+
210+
fields = {"proposal_boxes": Boxes, "a": Tensor}
211+
with patch_instances(fields):
212+
script_f = torch.jit.script(f)
213+
x = torch.randn(3, 4)
214+
output, output2 = script_f(x)
215+
self.assertTrue(torch.equal(output.a, torch.cat([x, x])))
216+
self.assertFalse(output.has("proposal_boxes"))
217+
self.assertTrue(torch.equal(output2.proposal_boxes.tensor, torch.cat([x, x])))
218+
184219

185220
if __name__ == "__main__":
186221
unittest.main()

0 commit comments

Comments
 (0)