Skip to content

Commit d283507

Browse files
Fixes layer names in backbone and minor cleanup
1 parent d42bc6a commit d283507

File tree

8 files changed

+187
-356
lines changed

8 files changed

+187
-356
lines changed

models/experimental/panoptic_deeplab/common.py

Lines changed: 79 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ def map_single_key(checkpoint_key):
3030

3131
# BACKBONE MAPPINGS
3232
if key.startswith("backbone."):
33-
# Layer mapping: res2/3/4/5 -> layer1/2/3/4
34-
key = key.replace("res2", "layer1")
35-
key = key.replace("res3", "layer2")
36-
key = key.replace("res4", "layer3")
37-
key = key.replace("res5", "layer4")
38-
3933
# Batch norm mapping: conv1/2/3.norm -> bn1/2/3
4034
key = key.replace("conv1.norm", "bn1")
4135
key = key.replace("conv2.norm", "bn2")
@@ -47,8 +41,6 @@ def map_single_key(checkpoint_key):
4741
if ".shortcut." in key and ".shortcut.norm." not in checkpoint_key:
4842
key = key.replace(".shortcut.", ".downsample.0.")
4943

50-
return key
51-
5244
# SEMANTIC HEAD MAPPINGS
5345
elif key.startswith("sem_seg_head."):
5446
# Replace base prefix
@@ -68,88 +60,6 @@ def map_single_key(checkpoint_key):
6860
else:
6961
key = key.replace(".head.depthwise.", ".head_1.conv1.0.")
7062

71-
# ASPP mappings (res5 -> aspp)
72-
elif ".decoder.res5.project_conv." in key:
73-
# Special case for ASPP_3_Depthwise
74-
if ".convs.3.depthwise.norm." in key:
75-
key = key.replace(".decoder.res5.project_conv.convs.3.depthwise.norm.", ".aspp.ASPP_3_Depthwise.1.")
76-
elif ".convs.3.depthwise." in key:
77-
key = key.replace(".decoder.res5.project_conv.convs.3.depthwise.", ".aspp.ASPP_3_Depthwise.0.")
78-
79-
# ASPP_0_Conv
80-
elif ".convs.0.norm." in key:
81-
key = key.replace(".decoder.res5.project_conv.convs.0.norm.", ".aspp.ASPP_0_Conv.1.")
82-
elif ".convs.0." in key:
83-
key = key.replace(".decoder.res5.project_conv.convs.0.", ".aspp.ASPP_0_Conv.0.")
84-
85-
# ASPP_1 Depthwise and Pointwise
86-
elif ".convs.1.depthwise.norm." in key:
87-
key = key.replace(".decoder.res5.project_conv.convs.1.depthwise.norm.", ".aspp.ASPP_1_Depthwise.1.")
88-
elif ".convs.1.depthwise." in key:
89-
key = key.replace(".decoder.res5.project_conv.convs.1.depthwise.", ".aspp.ASPP_1_Depthwise.0.")
90-
elif ".convs.1.pointwise.norm." in key:
91-
key = key.replace(".decoder.res5.project_conv.convs.1.pointwise.norm.", ".aspp.ASPP_1_pointwise.1.")
92-
elif ".convs.1.pointwise." in key:
93-
key = key.replace(".decoder.res5.project_conv.convs.1.pointwise.", ".aspp.ASPP_1_pointwise.0.")
94-
95-
# ASPP_2 Depthwise and Pointwise
96-
elif ".convs.2.depthwise.norm." in key:
97-
key = key.replace(".decoder.res5.project_conv.convs.2.depthwise.norm.", ".aspp.ASPP_2_Depthwise.1.")
98-
elif ".convs.2.depthwise." in key:
99-
key = key.replace(".decoder.res5.project_conv.convs.2.depthwise.", ".aspp.ASPP_2_Depthwise.0.")
100-
elif ".convs.2.pointwise.norm." in key:
101-
key = key.replace(".decoder.res5.project_conv.convs.2.pointwise.norm.", ".aspp.ASPP_2_pointwise.1.")
102-
elif ".convs.2.pointwise." in key:
103-
key = key.replace(".decoder.res5.project_conv.convs.2.pointwise.", ".aspp.ASPP_2_pointwise.0.")
104-
105-
# ASPP_3 Pointwise
106-
elif ".convs.3.pointwise.norm." in key:
107-
key = key.replace(".decoder.res5.project_conv.convs.3.pointwise.norm.", ".aspp.ASPP_3_pointwise.1.")
108-
elif ".convs.3.pointwise." in key:
109-
key = key.replace(".decoder.res5.project_conv.convs.3.pointwise.", ".aspp.ASPP_3_pointwise.0.")
110-
111-
# ASPP_4_Conv
112-
elif ".convs.4." in key:
113-
key = key.replace(".decoder.res5.project_conv.convs.4.1.", ".aspp.ASPP_4_Conv_1.0.")
114-
115-
# ASPP project
116-
elif ".project.norm." in key:
117-
key = key.replace(".decoder.res5.project_conv.project.norm.", ".aspp.ASPP_project.1.")
118-
elif ".project." in key:
119-
key = key.replace(".decoder.res5.project_conv.project.", ".aspp.ASPP_project.0.")
120-
121-
# Decoder res3 mappings
122-
elif ".decoder.res3." in key:
123-
if ".project_conv.norm." in key:
124-
key = key.replace(".decoder.res3.project_conv.norm.", ".res3.conv1.1.")
125-
elif ".project_conv." in key:
126-
key = key.replace(".decoder.res3.project_conv.", ".res3.conv1.0.")
127-
elif ".fuse_conv.depthwise.norm." in key:
128-
key = key.replace(".decoder.res3.fuse_conv.depthwise.norm.", ".res3.conv2.1.")
129-
elif ".fuse_conv.depthwise." in key:
130-
key = key.replace(".decoder.res3.fuse_conv.depthwise.", ".res3.conv2.0.")
131-
elif ".fuse_conv.pointwise.norm." in key:
132-
key = key.replace(".decoder.res3.fuse_conv.pointwise.norm.", ".res3.conv3.1.")
133-
elif ".fuse_conv.pointwise." in key:
134-
key = key.replace(".decoder.res3.fuse_conv.pointwise.", ".res3.conv3.0.")
135-
136-
# Decoder res2 mappings
137-
elif ".decoder.res2." in key:
138-
if ".project_conv.norm." in key:
139-
key = key.replace(".decoder.res2.project_conv.norm.", ".res2.conv1.1.")
140-
elif ".project_conv." in key:
141-
key = key.replace(".decoder.res2.project_conv.", ".res2.conv1.0.")
142-
elif ".fuse_conv.depthwise.norm." in key:
143-
key = key.replace(".decoder.res2.fuse_conv.depthwise.norm.", ".res2.conv2.1.")
144-
elif ".fuse_conv.depthwise." in key:
145-
key = key.replace(".decoder.res2.fuse_conv.depthwise.", ".res2.conv2.0.")
146-
elif ".fuse_conv.pointwise.norm." in key:
147-
key = key.replace(".decoder.res2.fuse_conv.pointwise.norm.", ".res2.conv3.1.")
148-
elif ".fuse_conv.pointwise." in key:
149-
key = key.replace(".decoder.res2.fuse_conv.pointwise.", ".res2.conv3.0.")
150-
151-
return key
152-
15363
# INSTANCE HEAD MAPPINGS
15464
elif key.startswith("ins_embed_head."):
15565
# Replace base prefix
@@ -179,87 +89,85 @@ def map_single_key(checkpoint_key):
17989
elif ".offset_predictor." in key:
18090
key = key.replace(".offset_predictor.", ".head_1.conv3.0.")
18191

182-
# ASPP mappings (res5 -> aspp)
183-
elif ".decoder.res5.project_conv." in key:
184-
# Special case for ASPP_3_Depthwise
185-
if ".convs.3.depthwise.norm." in key:
186-
key = key.replace(".decoder.res5.project_conv.convs.3.depthwise.norm.", ".aspp.ASPP_3_Depthwise.1.")
187-
elif ".convs.3.depthwise." in key:
188-
key = key.replace(".decoder.res5.project_conv.convs.3.depthwise.", ".aspp.ASPP_3_Depthwise.0.")
189-
190-
# ASPP_0_Conv
191-
elif ".convs.0.norm." in key:
192-
key = key.replace(".decoder.res5.project_conv.convs.0.norm.", ".aspp.ASPP_0_Conv.1.")
193-
elif ".convs.0." in key:
194-
key = key.replace(".decoder.res5.project_conv.convs.0.", ".aspp.ASPP_0_Conv.0.")
195-
196-
# ASPP_1 Depthwise and Pointwise
197-
elif ".convs.1.depthwise.norm." in key:
198-
key = key.replace(".decoder.res5.project_conv.convs.1.depthwise.norm.", ".aspp.ASPP_1_Depthwise.1.")
199-
elif ".convs.1.depthwise." in key:
200-
key = key.replace(".decoder.res5.project_conv.convs.1.depthwise.", ".aspp.ASPP_1_Depthwise.0.")
201-
elif ".convs.1.pointwise.norm." in key:
202-
key = key.replace(".decoder.res5.project_conv.convs.1.pointwise.norm.", ".aspp.ASPP_1_pointwise.1.")
203-
elif ".convs.1.pointwise." in key:
204-
key = key.replace(".decoder.res5.project_conv.convs.1.pointwise.", ".aspp.ASPP_1_pointwise.0.")
205-
206-
# ASPP_2 Depthwise and Pointwise
207-
elif ".convs.2.depthwise.norm." in key:
208-
key = key.replace(".decoder.res5.project_conv.convs.2.depthwise.norm.", ".aspp.ASPP_2_Depthwise.1.")
209-
elif ".convs.2.depthwise." in key:
210-
key = key.replace(".decoder.res5.project_conv.convs.2.depthwise.", ".aspp.ASPP_2_Depthwise.0.")
211-
elif ".convs.2.pointwise.norm." in key:
212-
key = key.replace(".decoder.res5.project_conv.convs.2.pointwise.norm.", ".aspp.ASPP_2_pointwise.1.")
213-
elif ".convs.2.pointwise." in key:
214-
key = key.replace(".decoder.res5.project_conv.convs.2.pointwise.", ".aspp.ASPP_2_pointwise.0.")
215-
216-
# ASPP_3 Pointwise
217-
elif ".convs.3.pointwise.norm." in key:
218-
key = key.replace(".decoder.res5.project_conv.convs.3.pointwise.norm.", ".aspp.ASPP_3_pointwise.1.")
219-
elif ".convs.3.pointwise." in key:
220-
key = key.replace(".decoder.res5.project_conv.convs.3.pointwise.", ".aspp.ASPP_3_pointwise.0.")
221-
222-
# ASPP_4_Conv
223-
elif ".convs.4." in key:
224-
key = key.replace(".decoder.res5.project_conv.convs.4.1.", ".aspp.ASPP_4_Conv_1.0.")
225-
226-
# ASPP project
227-
elif ".project.norm." in key:
228-
key = key.replace(".decoder.res5.project_conv.project.norm.", ".aspp.ASPP_project.1.")
229-
elif ".project." in key:
230-
key = key.replace(".decoder.res5.project_conv.project.", ".aspp.ASPP_project.0.")
231-
232-
# Decoder res3 mappings
233-
elif ".decoder.res3." in key:
234-
if ".project_conv.norm." in key:
235-
key = key.replace(".decoder.res3.project_conv.norm.", ".res3.conv1.1.")
236-
elif ".project_conv." in key:
237-
key = key.replace(".decoder.res3.project_conv.", ".res3.conv1.0.")
238-
elif ".fuse_conv.depthwise.norm." in key:
239-
key = key.replace(".decoder.res3.fuse_conv.depthwise.norm.", ".res3.conv2.1.")
240-
elif ".fuse_conv.depthwise." in key:
241-
key = key.replace(".decoder.res3.fuse_conv.depthwise.", ".res3.conv2.0.")
242-
elif ".fuse_conv.pointwise.norm." in key:
243-
key = key.replace(".decoder.res3.fuse_conv.pointwise.norm.", ".res3.conv3.1.")
244-
elif ".fuse_conv.pointwise." in key:
245-
key = key.replace(".decoder.res3.fuse_conv.pointwise.", ".res3.conv3.0.")
246-
247-
# Decoder res2 mappings
248-
elif ".decoder.res2." in key:
249-
if ".project_conv.norm." in key:
250-
key = key.replace(".decoder.res2.project_conv.norm.", ".res2.conv1.1.")
251-
elif ".project_conv." in key:
252-
key = key.replace(".decoder.res2.project_conv.", ".res2.conv1.0.")
253-
elif ".fuse_conv.depthwise.norm." in key:
254-
key = key.replace(".decoder.res2.fuse_conv.depthwise.norm.", ".res2.conv2.1.")
255-
elif ".fuse_conv.depthwise." in key:
256-
key = key.replace(".decoder.res2.fuse_conv.depthwise.", ".res2.conv2.0.")
257-
elif ".fuse_conv.pointwise.norm." in key:
258-
key = key.replace(".decoder.res2.fuse_conv.pointwise.norm.", ".res2.conv3.1.")
259-
elif ".fuse_conv.pointwise." in key:
260-
key = key.replace(".decoder.res2.fuse_conv.pointwise.", ".res2.conv3.0.")
261-
262-
return key
92+
# ASPP mappings (res5 -> aspp)
93+
if ".decoder.res5.project_conv." in key:
94+
# Special case for ASPP_3_Depthwise
95+
if ".convs.3.depthwise.norm." in key:
96+
key = key.replace(".decoder.res5.project_conv.convs.3.depthwise.norm.", ".aspp.ASPP_3_Depthwise.1.")
97+
elif ".convs.3.depthwise." in key:
98+
key = key.replace(".decoder.res5.project_conv.convs.3.depthwise.", ".aspp.ASPP_3_Depthwise.0.")
99+
100+
# ASPP_0_Conv
101+
elif ".convs.0.norm." in key:
102+
key = key.replace(".decoder.res5.project_conv.convs.0.norm.", ".aspp.ASPP_0_Conv.1.")
103+
elif ".convs.0." in key:
104+
key = key.replace(".decoder.res5.project_conv.convs.0.", ".aspp.ASPP_0_Conv.0.")
105+
106+
# ASPP_1 Depthwise and Pointwise
107+
elif ".convs.1.depthwise.norm." in key:
108+
key = key.replace(".decoder.res5.project_conv.convs.1.depthwise.norm.", ".aspp.ASPP_1_Depthwise.1.")
109+
elif ".convs.1.depthwise." in key:
110+
key = key.replace(".decoder.res5.project_conv.convs.1.depthwise.", ".aspp.ASPP_1_Depthwise.0.")
111+
elif ".convs.1.pointwise.norm." in key:
112+
key = key.replace(".decoder.res5.project_conv.convs.1.pointwise.norm.", ".aspp.ASPP_1_pointwise.1.")
113+
elif ".convs.1.pointwise." in key:
114+
key = key.replace(".decoder.res5.project_conv.convs.1.pointwise.", ".aspp.ASPP_1_pointwise.0.")
115+
116+
# ASPP_2 Depthwise and Pointwise
117+
elif ".convs.2.depthwise.norm." in key:
118+
key = key.replace(".decoder.res5.project_conv.convs.2.depthwise.norm.", ".aspp.ASPP_2_Depthwise.1.")
119+
elif ".convs.2.depthwise." in key:
120+
key = key.replace(".decoder.res5.project_conv.convs.2.depthwise.", ".aspp.ASPP_2_Depthwise.0.")
121+
elif ".convs.2.pointwise.norm." in key:
122+
key = key.replace(".decoder.res5.project_conv.convs.2.pointwise.norm.", ".aspp.ASPP_2_pointwise.1.")
123+
elif ".convs.2.pointwise." in key:
124+
key = key.replace(".decoder.res5.project_conv.convs.2.pointwise.", ".aspp.ASPP_2_pointwise.0.")
125+
126+
# ASPP_3 Pointwise
127+
elif ".convs.3.pointwise.norm." in key:
128+
key = key.replace(".decoder.res5.project_conv.convs.3.pointwise.norm.", ".aspp.ASPP_3_pointwise.1.")
129+
elif ".convs.3.pointwise." in key:
130+
key = key.replace(".decoder.res5.project_conv.convs.3.pointwise.", ".aspp.ASPP_3_pointwise.0.")
131+
132+
# ASPP_4_Conv
133+
elif ".convs.4." in key:
134+
key = key.replace(".decoder.res5.project_conv.convs.4.1.", ".aspp.ASPP_4_Conv_1.0.")
135+
136+
# ASPP project
137+
elif ".project.norm." in key:
138+
key = key.replace(".decoder.res5.project_conv.project.norm.", ".aspp.ASPP_project.1.")
139+
elif ".project." in key:
140+
key = key.replace(".decoder.res5.project_conv.project.", ".aspp.ASPP_project.0.")
141+
142+
# Decoder res3 mappings
143+
elif ".decoder.res3." in key:
144+
if ".project_conv.norm." in key:
145+
key = key.replace(".decoder.res3.project_conv.norm.", ".res3.conv1.1.")
146+
elif ".project_conv." in key:
147+
key = key.replace(".decoder.res3.project_conv.", ".res3.conv1.0.")
148+
elif ".fuse_conv.depthwise.norm." in key:
149+
key = key.replace(".decoder.res3.fuse_conv.depthwise.norm.", ".res3.conv2.1.")
150+
elif ".fuse_conv.depthwise." in key:
151+
key = key.replace(".decoder.res3.fuse_conv.depthwise.", ".res3.conv2.0.")
152+
elif ".fuse_conv.pointwise.norm." in key:
153+
key = key.replace(".decoder.res3.fuse_conv.pointwise.norm.", ".res3.conv3.1.")
154+
elif ".fuse_conv.pointwise." in key:
155+
key = key.replace(".decoder.res3.fuse_conv.pointwise.", ".res3.conv3.0.")
156+
157+
# Decoder res2 mappings
158+
elif ".decoder.res2." in key:
159+
if ".project_conv.norm." in key:
160+
key = key.replace(".decoder.res2.project_conv.norm.", ".res2.conv1.1.")
161+
elif ".project_conv." in key:
162+
key = key.replace(".decoder.res2.project_conv.", ".res2.conv1.0.")
163+
elif ".fuse_conv.depthwise.norm." in key:
164+
key = key.replace(".decoder.res2.fuse_conv.depthwise.norm.", ".res2.conv2.1.")
165+
elif ".fuse_conv.depthwise." in key:
166+
key = key.replace(".decoder.res2.fuse_conv.depthwise.", ".res2.conv2.0.")
167+
elif ".fuse_conv.pointwise.norm." in key:
168+
key = key.replace(".decoder.res2.fuse_conv.pointwise.norm.", ".res2.conv3.1.")
169+
elif ".fuse_conv.pointwise." in key:
170+
key = key.replace(".decoder.res2.fuse_conv.pointwise.", ".res2.conv3.0.")
263171

264172
return key
265173

models/experimental/panoptic_deeplab/reference/resnet52_backbone.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def __init__(
3636
self.groups = groups
3737
self.base_width = width_per_group
3838
self.stem = DeepLabStem(in_channels=3, out_channels=self.inplanes, stride=1)
39-
self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dialate_config=dialate_layer_config[0])
40-
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dialate_config=dialate_layer_config[1])
41-
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dialate_config=dialate_layer_config[2])
42-
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dialate_config=[2, 4, 8])
39+
self.res2 = self._make_layer(block, 64, layers[0], stride=1, dialate_config=dialate_layer_config[0])
40+
self.res3 = self._make_layer(block, 128, layers[1], stride=2, dialate_config=dialate_layer_config[1])
41+
self.res4 = self._make_layer(block, 256, layers[2], stride=2, dialate_config=dialate_layer_config[2])
42+
self.res5 = self._make_layer(block, 512, layers[3], stride=1, dialate_config=[2, 4, 8])
4343

4444
def _make_layer(
4545
self,
@@ -83,10 +83,10 @@ def _make_layer(
8383
def forward(self, x: Tensor) -> Tensor:
8484
x = self.stem(x)
8585

86-
res_2 = self.layer1(x)
87-
res_3 = self.layer2(res_2)
88-
res_4 = self.layer3(res_3)
89-
res_5 = self.layer4(res_4)
86+
res_2 = self.res2(x)
87+
res_3 = self.res3(res_2)
88+
res_4 = self.res4(res_3)
89+
res_5 = self.res5(res_4)
9090
out = {"res_2": res_2, "res_3": res_3, "res_5": res_5}
9191

9292
return out

models/experimental/panoptic_deeplab/tests/test_resnet52_backbone.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
):
2727
super().__init__()
2828
if not hasattr(self, "_model_initialized"):
29-
torch.manual_seed(42) # Only seed once
29+
torch.manual_seed(42)
3030
self._model_initialized = True
3131
torch.cuda.manual_seed_all(42)
3232
torch.backends.cudnn.deterministic = True
@@ -63,15 +63,8 @@ def __init__(
6363
self.ttnn_model = TTBackbone(
6464
parameters=parameters,
6565
model_config=model_config,
66-
small_tensor=width < 2048,
6766
)
6867

69-
# First run configures convs JIT
70-
self.input_tensor = ttnn.to_device(tt_host_tensor, device)
71-
self.run()
72-
self.validate()
73-
74-
# Optimized run
7568
self.input_tensor = ttnn.to_device(tt_host_tensor, device)
7669
self.run()
7770
self.validate()
@@ -99,8 +92,7 @@ def validate(self, output_tensor=None):
9992
valid_pcc = {
10093
"res_2": 0.99,
10194
"res_3": 0.99,
102-
"res_4": 0.99,
103-
"res_5": 0.98,
95+
"res_5": 0.99,
10496
}
10597
self.pcc_passed_all = []
10698
self.pcc_message_all = []

0 commit comments

Comments
 (0)