Skip to content

Commit 11bb130

Browse files
committed
custom_preprocessor refactor
1 parent b486e87 commit 11bb130

File tree

1 file changed

+12
-86
lines changed

1 file changed

+12
-86
lines changed

models/experimental/panoptic_deeplab/tt/custom_preprocessing.py

Lines changed: 12 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,6 @@
99
from models.utility_functions import pad_and_fold_conv_filters_for_unity_stride
1010
from models.experimental.panoptic_deeplab.reference.resnet52_stem import DeepLabStem
1111

12-
from models.experimental.panoptic_deeplab.reference.head import (
13-
HeadModel,
14-
)
15-
from models.experimental.panoptic_deeplab.reference.res_block import (
16-
ResModel,
17-
)
18-
from models.experimental.panoptic_deeplab.reference.aspp import (
19-
ASPPModel,
20-
)
2112
from models.experimental.panoptic_deeplab.reference.decoder import (
2213
DecoderModel,
2314
)
@@ -83,74 +74,21 @@ def custom_preprocessor(
8374
parameters["conv2"]["bias"] = ttnn.from_torch(torch.reshape(conv2_bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper)
8475
parameters["conv3"]["bias"] = ttnn.from_torch(torch.reshape(conv3_bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper)
8576

86-
elif isinstance(model, HeadModel):
87-
for name, module in model.named_children():
88-
if hasattr(module, "__getitem__"):
89-
if len(module) > 1 and hasattr(module[0], "weight") and hasattr(module[1], "weight"):
90-
# Assume Conv + BN, fold BN into Conv
91-
weight, bias = fold_batch_norm2d_into_conv2d(module[0], module[1])
92-
elif hasattr(module[0], "weight"):
93-
# Just a Conv, no BN
94-
weight = module[0].weight.clone().detach().contiguous()
95-
bias = (
96-
module[0].bias.clone().detach().contiguous()
97-
if module[0].bias is not None
98-
else torch.zeros(module[0].out_channels)
99-
)
100-
else:
101-
continue
102-
elif hasattr(module, "weight"):
103-
# Single Conv2d
104-
weight = module.weight.clone().detach().contiguous()
105-
bias = (
106-
module.bias.clone().detach().contiguous()
107-
if module.bias is not None
108-
else torch.zeros(module.out_channels)
109-
)
110-
else:
111-
continue
112-
113-
parameters[name] = {}
114-
parameters[name]["weight"] = ttnn.from_torch(weight, mesh_mapper=mesh_mapper)
115-
parameters[name]["bias"] = ttnn.from_torch(torch.reshape(bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper)
116-
117-
elif isinstance(model, ASPPModel):
118-
for name, module in model.named_children():
119-
# For each submodule (e.g., ASPP_0_Conv, ASPP_1_Depthwise, etc.)
120-
if hasattr(module, "__getitem__"):
121-
# If it's a Sequential or similar
122-
if len(module) > 1 and hasattr(module[0], "weight") and hasattr(module[1], "weight"):
123-
# Assume Conv + BN, fold BN into Conv
124-
weight, bias = fold_batch_norm2d_into_conv2d(module[0], module[1])
125-
elif hasattr(module[0], "weight"):
126-
# Just a Conv, no BN
127-
weight = module[0].weight.clone().detach().contiguous()
128-
bias = (
129-
module[0].bias.clone().detach().contiguous()
130-
if module[0].bias is not None
131-
else torch.zeros(module[0].out_channels)
132-
)
133-
else:
134-
continue
135-
elif hasattr(module, "weight"):
136-
# Single Conv2d
137-
weight = module.weight.clone().detach().contiguous()
138-
bias = (
139-
module.bias.clone().detach().contiguous()
140-
if module.bias is not None
141-
else torch.zeros(module.out_channels)
142-
)
143-
else:
144-
continue
145-
146-
parameters[name] = {}
147-
parameters[name]["weight"] = ttnn.from_torch(weight, mesh_mapper=mesh_mapper)
148-
parameters[name]["bias"] = ttnn.from_torch(torch.reshape(bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper)
77+
elif isinstance(model, DecoderModel):
78+
parameters = {}
79+
# Let the sub-modules handle their own preprocessing
80+
for child_name, child in model.named_children():
81+
parameters[child_name] = convert_torch_model_to_ttnn_model(
82+
child,
83+
name=f"{name}.{child_name}",
84+
custom_preprocessor=custom_preprocessor_func,
85+
convert_to_ttnn=convert_to_ttnn,
86+
ttnn_module_args=ttnn_module_args,
87+
)
14988

150-
elif isinstance(model, ResModel):
89+
else:
15190
for name, module in model.named_children():
15291
if hasattr(module, "__getitem__"):
153-
# If it's a Sequential or similar
15492
if len(module) > 1 and hasattr(module[0], "weight") and hasattr(module[1], "weight"):
15593
# Assume Conv + BN, fold BN into Conv
15694
weight, bias = fold_batch_norm2d_into_conv2d(module[0], module[1])
@@ -179,18 +117,6 @@ def custom_preprocessor(
179117
parameters[name]["weight"] = ttnn.from_torch(weight, mesh_mapper=mesh_mapper)
180118
parameters[name]["bias"] = ttnn.from_torch(torch.reshape(bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper)
181119

182-
elif isinstance(model, DecoderModel):
183-
parameters = {}
184-
# Let the sub-modules handle their own preprocessing
185-
for child_name, child in model.named_children():
186-
parameters[child_name] = convert_torch_model_to_ttnn_model(
187-
child,
188-
name=f"{name}.{child_name}",
189-
custom_preprocessor=custom_preprocessor_func,
190-
convert_to_ttnn=convert_to_ttnn,
191-
ttnn_module_args=ttnn_module_args,
192-
)
193-
194120
return parameters
195121

196122

0 commit comments

Comments (0)