9 | 9 | from models.utility_functions import pad_and_fold_conv_filters_for_unity_stride |
10 | 10 | from models.experimental.panoptic_deeplab.reference.resnet52_stem import DeepLabStem |
11 | 11 |
12 | | -from models.experimental.panoptic_deeplab.reference.head import ( |
13 | | - HeadModel, |
14 | | -) |
15 | | -from models.experimental.panoptic_deeplab.reference.res_block import ( |
16 | | - ResModel, |
17 | | -) |
18 | | -from models.experimental.panoptic_deeplab.reference.aspp import ( |
19 | | - ASPPModel, |
20 | | -) |
21 | 12 | from models.experimental.panoptic_deeplab.reference.decoder import ( |
22 | 13 | DecoderModel, |
23 | 14 | ) |
@@ -83,74 +74,21 @@ def custom_preprocessor( |
83 | 74 | parameters["conv2"]["bias"] = ttnn.from_torch(torch.reshape(conv2_bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper) |
84 | 75 | parameters["conv3"]["bias"] = ttnn.from_torch(torch.reshape(conv3_bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper) |
85 | 76 |
86 | | - elif isinstance(model, HeadModel): |
87 | | - for name, module in model.named_children(): |
88 | | - if hasattr(module, "__getitem__"): |
89 | | - if len(module) > 1 and hasattr(module[0], "weight") and hasattr(module[1], "weight"): |
90 | | - # Assume Conv + BN, fold BN into Conv |
91 | | - weight, bias = fold_batch_norm2d_into_conv2d(module[0], module[1]) |
92 | | - elif hasattr(module[0], "weight"): |
93 | | - # Just a Conv, no BN |
94 | | - weight = module[0].weight.clone().detach().contiguous() |
95 | | - bias = ( |
96 | | - module[0].bias.clone().detach().contiguous() |
97 | | - if module[0].bias is not None |
98 | | - else torch.zeros(module[0].out_channels) |
99 | | - ) |
100 | | - else: |
101 | | - continue |
102 | | - elif hasattr(module, "weight"): |
103 | | - # Single Conv2d |
104 | | - weight = module.weight.clone().detach().contiguous() |
105 | | - bias = ( |
106 | | - module.bias.clone().detach().contiguous() |
107 | | - if module.bias is not None |
108 | | - else torch.zeros(module.out_channels) |
109 | | - ) |
110 | | - else: |
111 | | - continue |
112 | | - |
113 | | - parameters[name] = {} |
114 | | - parameters[name]["weight"] = ttnn.from_torch(weight, mesh_mapper=mesh_mapper) |
115 | | - parameters[name]["bias"] = ttnn.from_torch(torch.reshape(bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper) |
116 | | - |
117 | | - elif isinstance(model, ASPPModel): |
118 | | - for name, module in model.named_children(): |
119 | | - # For each submodule (e.g., ASPP_0_Conv, ASPP_1_Depthwise, etc.) |
120 | | - if hasattr(module, "__getitem__"): |
121 | | - # If it's a Sequential or similar |
122 | | - if len(module) > 1 and hasattr(module[0], "weight") and hasattr(module[1], "weight"): |
123 | | - # Assume Conv + BN, fold BN into Conv |
124 | | - weight, bias = fold_batch_norm2d_into_conv2d(module[0], module[1]) |
125 | | - elif hasattr(module[0], "weight"): |
126 | | - # Just a Conv, no BN |
127 | | - weight = module[0].weight.clone().detach().contiguous() |
128 | | - bias = ( |
129 | | - module[0].bias.clone().detach().contiguous() |
130 | | - if module[0].bias is not None |
131 | | - else torch.zeros(module[0].out_channels) |
132 | | - ) |
133 | | - else: |
134 | | - continue |
135 | | - elif hasattr(module, "weight"): |
136 | | - # Single Conv2d |
137 | | - weight = module.weight.clone().detach().contiguous() |
138 | | - bias = ( |
139 | | - module.bias.clone().detach().contiguous() |
140 | | - if module.bias is not None |
141 | | - else torch.zeros(module.out_channels) |
142 | | - ) |
143 | | - else: |
144 | | - continue |
145 | | - |
146 | | - parameters[name] = {} |
147 | | - parameters[name]["weight"] = ttnn.from_torch(weight, mesh_mapper=mesh_mapper) |
148 | | - parameters[name]["bias"] = ttnn.from_torch(torch.reshape(bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper) |
| 77 | + elif isinstance(model, DecoderModel): |
| 78 | + parameters = {} |
| 79 | + # Let the sub-modules handle their own preprocessing |
| 80 | + for child_name, child in model.named_children(): |
| 81 | + parameters[child_name] = convert_torch_model_to_ttnn_model( |
| 82 | + child, |
| 83 | + name=f"{name}.{child_name}", |
| 84 | + custom_preprocessor=custom_preprocessor_func, |
| 85 | + convert_to_ttnn=convert_to_ttnn, |
| 86 | + ttnn_module_args=ttnn_module_args, |
| 87 | + ) |
149 | 88 |
150 | | - elif isinstance(model, ResModel): |
| 89 | + else: |
151 | 90 | for name, module in model.named_children(): |
152 | 91 | if hasattr(module, "__getitem__"): |
153 | | - # If it's a Sequential or similar |
154 | 92 | if len(module) > 1 and hasattr(module[0], "weight") and hasattr(module[1], "weight"): |
155 | 93 | # Assume Conv + BN, fold BN into Conv |
156 | 94 | weight, bias = fold_batch_norm2d_into_conv2d(module[0], module[1]) |
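
For context, the fold applied by `fold_batch_norm2d_into_conv2d` above is the standard BatchNorm-into-Conv identity: per output channel, scale = gamma / sqrt(running_var + eps), then W' = W * scale and b' = beta + (b - running_mean) * scale. A minimal sketch under those assumptions (the helper name `fold_bn_into_conv` is illustrative, not the library function):

    import torch

    def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d):
        # Per-output-channel scale from the BN statistics.
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # Fold the scale into the conv weight (broadcasts over in_ch, kH, kW).
        weight = conv.weight * scale.reshape(-1, 1, 1, 1)
        # Fold the mean/shift into the bias, treating a missing conv bias as zero.
        bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        bias = bn.bias + (bias - bn.running_mean) * scale
        return weight, bias
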
@@ -179,18 +117,6 @@ def custom_preprocessor( |
179 | 117 | parameters[name]["weight"] = ttnn.from_torch(weight, mesh_mapper=mesh_mapper) |
180 | 118 | parameters[name]["bias"] = ttnn.from_torch(torch.reshape(bias, (1, 1, 1, -1)), mesh_mapper=mesh_mapper) |
181 | 119 |
182 | | - elif isinstance(model, DecoderModel): |
183 | | - parameters = {} |
184 | | - # Let the sub-modules handle their own preprocessing |
185 | | - for child_name, child in model.named_children(): |
186 | | - parameters[child_name] = convert_torch_model_to_ttnn_model( |
187 | | - child, |
188 | | - name=f"{name}.{child_name}", |
189 | | - custom_preprocessor=custom_preprocessor_func, |
190 | | - convert_to_ttnn=convert_to_ttnn, |
191 | | - ttnn_module_args=ttnn_module_args, |
192 | | - ) |
193 | | - |
194 | 120 | return parameters |
195 | 121 |
196 | 122 |
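
The new `DecoderModel` branch replaces the three deleted per-module branches by recursing through `convert_torch_model_to_ttnn_model`, so the decoder's ASPP, res-block, and head children now all fall through to the shared `else` branch. A minimal sketch of how such a preprocessor is typically driven, assuming the usual `ttnn.model_preprocessing.preprocess_model_parameters` entry point (the model construction and `device=None` are illustrative, and any mesh_mapper plumbing is omitted):

    from ttnn.model_preprocessing import preprocess_model_parameters

    # Hypothetical driver: DecoderModel and custom_preprocessor as in this file.
    parameters = preprocess_model_parameters(
        initialize_model=lambda: DecoderModel().eval(),
        custom_preprocessor=custom_preprocessor,
        device=None,  # keep weights on host; pass a ttnn device to move them
    )
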