
Commit d42bc6a

Trained weight loading for unit tests
1 parent 3917c03 commit d42bc6a

File tree

16 files changed: +175 -133 lines changed


models/experimental/panoptic_deeplab/common.py

Lines changed: 28 additions & 14 deletions
@@ -11,6 +11,11 @@
 from models.experimental.panoptic_deeplab.reference.resnet52_backbone import ResNet52BackBone as TorchBackbone
 from models.experimental.panoptic_deeplab.reference.resnet52_stem import DeepLabStem
 from torchvision.models.resnet import Bottleneck
+from models.experimental.panoptic_deeplab.reference.aspp import ASPPModel
+from models.experimental.panoptic_deeplab.reference.decoder import DecoderModel
+from models.experimental.panoptic_deeplab.reference.res_block import ResModel
+from models.experimental.panoptic_deeplab.reference.head import HeadModel
+from models.experimental.panoptic_deeplab.reference.panoptic_deeplab import TorchPanopticDeepLab
 
 
 def map_single_key(checkpoint_key):
@@ -259,17 +264,18 @@ def map_single_key(checkpoint_key):
     return key
 
 
-def load_partial_state(torch_model: torch.nn.Module, state_dict, layer_prefix: str = ""):
+def load_partial_state(torch_model: torch.nn.Module, state_dict, layer_name: str = ""):
     partial_state_dict = {}
+    layer_prefix = layer_name + "."
     for k, v in state_dict.items():
         if k.startswith(layer_prefix):
             partial_state_dict[k[len(layer_prefix) :]] = v
     torch_model.load_state_dict(partial_state_dict, strict=True)
     logger.info(f"Successfully loaded all mapped weights with strict=True")
-    return torch_model.eval()
+    return torch_model
 
 
-def load_torch_model_state(model: torch.nn.Module = None, layer_name: str = "", model_location_generator=None):
+def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str = "", model_location_generator=None):
     if model_location_generator == None or "TT_GH_CI_INFRA" not in os.environ:
         model_path = "models"
     else:
@@ -312,16 +318,24 @@ def load_torch_model_state(model: torch.nn.Module = None, layer_name: str = "",
     for checkpoint_key, model_key in key_mapping.items():
         mapped_state_dict[model_key] = state_dict[checkpoint_key]
 
-    if model is None:
-        return mapped_state_dict
-    elif isinstance(model, TorchBackbone):
-        layer_prefix = "backbone."
-        return load_partial_state(model, mapped_state_dict, layer_prefix)
-    elif isinstance(model, DeepLabStem):
-        layer_prefix = "backbone.stem."
-        return load_partial_state(model, mapped_state_dict, layer_prefix)
-    elif isinstance(model, Bottleneck):
-        layer_prefix = "backbone." + layer_name + "."
-        return load_partial_state(model, mapped_state_dict, layer_prefix)
+    if isinstance(
+        torch_model,
+        (
+            DeepLabStem,
+            Bottleneck,
+            TorchBackbone,
+            ASPPModel,
+            ResModel,
+            HeadModel,
+            DecoderModel,
+        ),
+    ):
+        torch_model = load_partial_state(torch_model, mapped_state_dict, layer_name)
+    elif isinstance(torch_model, TorchPanopticDeepLab):
+        del mapped_state_dict["pixel_mean"]
+        del mapped_state_dict["pixel_std"]
+        torch_model.load_state_dict(mapped_state_dict, strict=True)
     else:
         raise NotImplementedError("Unknown torch model. Weight loading not implemented")
+
+    return torch_model.eval()
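For reference, a minimal sketch of how the reworked loader is meant to be driven from a unit test (assumptions: the module paths shown in this diff, and a checkpoint reachable through the unchanged path logic inside load_torch_model_state):

import torch

from models.experimental.panoptic_deeplab.common import load_torch_model_state
from models.experimental.panoptic_deeplab.reference.aspp import ASPPModel

# Sub-module weights are selected by their fully qualified layer name; for the
# ASPP block inside the semantic decoder, load_partial_state strips the
# "semantic_decoder.aspp." prefix and then calls load_state_dict(strict=True).
torch_model = ASPPModel()
torch_model = load_torch_model_state(torch_model, "semantic_decoder.aspp")

# The loader now returns the model already in eval() mode, so the golden
# output can be computed directly (input shape taken from test_aspp.py).
with torch.no_grad():
    golden = torch_model(torch.randn(1, 2048, 32, 64))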

models/experimental/panoptic_deeplab/reference/decoder.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ def __init__(self, name) -> None:
         super().__init__()
         self.name = name
         self.aspp = ASPPModel()
-        if name == "semantics_head":
+        if name == "semantic_decoder":
             self.res3 = ResModel(512, 320, 256)
             self.res2 = ResModel(256, 288, 256)
             self.head_1 = HeadModel(256, 256, 19)
@@ -48,7 +48,7 @@ def forward(self, x: Tensor, res3: Tensor, res2: Tensor) -> Tuple[Tensor, Tensor
         out_ = self.res2(out, res2)
         out = self.head_1(out_)
 
-        if self.name == "instance_head":
+        if self.name == "instance_decoder":
             out_2 = self.head_2(out_)
         else:
             out_2 = None
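The rename makes the decoder's name string match the attribute it is assigned to in TorchPanopticDeepLab, so checkpoint key prefixes such as "semantic_decoder.aspp..." line up with the reference modules. A small sketch of the two configurations this selects between:

from models.experimental.panoptic_deeplab.reference.decoder import DecoderModel

# "semantic_decoder" builds head_1 with 19 output channels; "instance_decoder"
# additionally returns head_2's output as the second element of forward().
semantic_decoder = DecoderModel(name="semantic_decoder")
instance_decoder = DecoderModel(name="instance_decoder")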

models/experimental/panoptic_deeplab/reference/head.py

Lines changed: 8 additions & 4 deletions
@@ -22,20 +22,24 @@ def __init__(self, in_channels, intermediate_channels, out_channels) -> None:
 
         if out_channels == 1:  # instance center head
             self.conv1 = nn.Sequential(
-                nn.Conv2d(in_channels, in_channels, 3, 1, 1, 1), nn.BatchNorm2d(in_channels), nn.ReLU()
+                nn.Conv2d(in_channels, in_channels, 3, 1, 1, 1, bias=False), nn.BatchNorm2d(in_channels), nn.ReLU()
             )
 
             self.conv2 = nn.Sequential(
-                nn.Conv2d(in_channels, intermediate_channels, 3, 1, 1, 1),
+                nn.Conv2d(in_channels, intermediate_channels, 3, 1, 1, 1, bias=False),
                 nn.BatchNorm2d(intermediate_channels),
                 nn.ReLU(),
             )
         else:  # instance offset head and semantics head
             self.conv1 = nn.Sequential(
-                nn.Conv2d(in_channels, in_channels, 5, 1, 2, 1, in_channels), nn.BatchNorm2d(in_channels), nn.ReLU()
+                nn.Conv2d(in_channels, in_channels, 5, 1, 2, 1, in_channels, bias=False),
+                nn.BatchNorm2d(in_channels),
+                nn.ReLU(),
             )
             self.conv2 = nn.Sequential(
-                nn.Conv2d(in_channels, intermediate_channels, 1, 1), nn.BatchNorm2d(intermediate_channels), nn.ReLU()
+                nn.Conv2d(in_channels, intermediate_channels, 1, 1, bias=False),
+                nn.BatchNorm2d(intermediate_channels),
+                nn.ReLU(),
             )
         self.conv3 = nn.Sequential(nn.Conv2d(intermediate_channels, out_channels, 1, 1))
 
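A convolution bias immediately in front of BatchNorm2d is redundant (the normalization subtracts the per-channel mean), and dropping it presumably also keeps HeadModel consistent with trained checkpoints loaded under strict=True. A hedged sketch of the depthwise conv/BN/ReLU pattern used above, with nn.Conv2d's positional arguments spelled out:

import torch
import torch.nn as nn

in_channels = 256
# nn.Conv2d positional args: in, out, kernel_size, stride, padding, dilation, groups.
# groups=in_channels makes this a 5x5 depthwise convolution; bias=False because the
# following BatchNorm2d re-centres the activations anyway.
depthwise_block = nn.Sequential(
    nn.Conv2d(in_channels, in_channels, 5, 1, 2, 1, in_channels, bias=False),
    nn.BatchNorm2d(in_channels),
    nn.ReLU(),
)
print(depthwise_block(torch.randn(1, in_channels, 32, 64)).shape)  # torch.Size([1, 256, 32, 64])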

models/experimental/panoptic_deeplab/reference/panoptic_deeplab.py

Lines changed: 9 additions & 2 deletions
@@ -20,17 +20,24 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        # self.pixel_std = nn.Parameter(torch.randn((3, 1, 1)))
+        # self.pixel_mean = nn.Parameter(torch.randn((3, 1, 1)))
+        # self.register_buffer("pixel_mean", torch.randn(3).view(-1, 1, 1), False)
+        # self.register_buffer("pixel_std", torch.randn(3).view(-1, 1, 1), False)
+        # self.register_buffer("adsaf", torch.randn(3).view(-1, 1, 1), False)
+        # self.register_buffer("yurfdgdf", torch.randn(3).view(-1, 1, 1), False)
+
         # Backbone
         self.backbone = ResNet52BackBone()
 
         # Semantic segmentation decoder
         self.semantic_decoder = DecoderModel(
-            name="Semantics_head",
+            name="semantic_decoder",
         )
 
         # Instance segmentation decoders
         self.instance_decoder = DecoderModel(
-            name="instance_head",
+            name="instance_decoder",
         )
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
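The commented-out pixel_mean/pixel_std experiments pair with the common.py change that deletes those two checkpoint entries before strict loading. End-to-end loading of the reference model then looks roughly like the sketch below (an assumption: TorchPanopticDeepLab's constructor defaults are sufficient here):

from models.experimental.panoptic_deeplab.common import load_torch_model_state
from models.experimental.panoptic_deeplab.reference.panoptic_deeplab import TorchPanopticDeepLab

# pixel_mean/pixel_std are dropped from the mapped state dict in common.py because
# the reference module does not register them; the rest loads with strict=True.
model = load_torch_model_state(TorchPanopticDeepLab())  # returned in eval() mode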

models/experimental/panoptic_deeplab/tests/test_aspp.py

Lines changed: 28 additions & 19 deletions
@@ -9,10 +9,9 @@
 
 from tests.ttnn.utils_for_testing import check_with_pcc
 from models.experimental.panoptic_deeplab.tt.custom_preprocessing import create_custom_mesh_preprocessor
-from models.experimental.panoptic_deeplab.reference.aspp import (
-    ASPPModel,
-)
+from models.experimental.panoptic_deeplab.reference.aspp import ASPPModel
 from models.experimental.panoptic_deeplab.tt.aspp import TTASPP
+from models.experimental.panoptic_deeplab.common import load_torch_model_state
 
 
 class AsppTestInfra:
@@ -24,6 +23,7 @@ def __init__(
         height,
         width,
         model_config,
+        name,
     ):
         super().__init__()
         if not hasattr(self, "_model_initialized"):
@@ -37,19 +37,20 @@ def __init__(
         self.num_devices = device.get_num_devices()
         self.batch_size = batch_size
         self.inputs_mesh_mapper, self.weights_mesh_mapper, self.output_mesh_composer = self.get_mesh_mappers(device)
+        self.name = name
 
         # torch model
-        torch_model = ASPPModel().eval()
-        self.torch_input_tensor = torch.randn((batch_size, input_channels, height, width), dtype=torch.float16)
+        torch_model = ASPPModel()
+        torch_model = load_torch_model_state(torch_model, name)
+
         parameters = preprocess_model_parameters(
             initialize_model=lambda: torch_model,
             custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper),
             device=None,
         )
-        torch_model.to(torch.bfloat16)
-        self.torch_input_tensor = self.torch_input_tensor.to(torch.bfloat16)
 
         # golden
+        self.torch_input_tensor = torch.randn((batch_size, input_channels, height, width), dtype=torch.float)
         self.torch_output_tensor = torch_model(self.torch_input_tensor)
 
         # ttnn
@@ -67,7 +68,6 @@ def __init__(
         # run and validate
         self.run()
         self.validate()
-        ttnn.deallocate(self.output_tensor)
 
     def get_mesh_mappers(self, device):
         if device.get_num_devices() != 1:
@@ -85,23 +85,30 @@ def run(self):
         return self.output_tensor
 
     def validate(self, output_tensor=None):
-        """Validate outputs"""
+        tt_output_tensor = self.output_tensor if output_tensor is None else output_tensor
+        tt_output_tensor_torch = ttnn.to_torch(
+            tt_output_tensor, device=self.device, mesh_composer=self.output_mesh_composer
+        )
+
+        # Deallocate output tensors
+        ttnn.deallocate(tt_output_tensor)
 
-        output_tensor = self.output_tensor if output_tensor is None else output_tensor
-        output_tensor = ttnn.to_torch(output_tensor, device=self.device, mesh_composer=self.output_mesh_composer)
         expected_shape = self.torch_output_tensor.shape
-        output_tensor = torch.reshape(
-            output_tensor, (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1])
+        tt_output_tensor_torch = torch.reshape(
+            tt_output_tensor_torch, (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1])
         )
-        output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))
-        batch_size = self.batch_size
+        tt_output_tensor_torch = torch.permute(tt_output_tensor_torch, (0, 3, 1, 2))
+
+        batch_size = tt_output_tensor_torch.shape[0]
 
-        valid_pcc = 0.97
-        self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc)
+        valid_pcc = 0.99
+        self.pcc_passed, self.pcc_message = check_with_pcc(
+            self.torch_output_tensor, tt_output_tensor_torch, pcc=valid_pcc
+        )
 
         assert self.pcc_passed, logger.error(f"PCC check failed: {self.pcc_message}")
         logger.info(
-            f"Modular Panoptic DeepLab ASPP - batch_size={batch_size}, act_dtype={model_config['ACTIVATIONS_DTYPE']}, weight_dtype={model_config['WEIGHTS_DTYPE']}, math_fidelity={model_config['MATH_FIDELITY']}, PCC={self.pcc_message}"
+            f"Modular Panoptic DeepLab ASPP layer:{self.name} - batch_size={batch_size}, act_dtype={model_config['ACTIVATIONS_DTYPE']}, weight_dtype={model_config['WEIGHTS_DTYPE']}, math_fidelity={model_config['MATH_FIDELITY']}, PCC={self.pcc_message}"
         )
 
         return self.pcc_passed, self.pcc_message
@@ -121,12 +128,14 @@ def validate(self, output_tensor=None):
         (1, 2048, 32, 64),
     ],
 )
-def test_aspp(device, batch_size, input_channels, height, width):
+@pytest.mark.parametrize("name", ["semantic_decoder.aspp", "instance_decoder.aspp"])
+def test_aspp(device, batch_size, input_channels, height, width, name):
     AsppTestInfra(
         device,
         batch_size,
         input_channels,
         height,
         width,
         model_config,
+        name,
     )
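For context on the tightened 0.99 threshold: check_with_pcc scores the ttnn output against the torch golden output with a Pearson correlation coefficient. A hedged reference computation of that metric (the tests themselves use tests.ttnn.utils_for_testing.check_with_pcc):

import torch

def pearson_cc(golden: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten both tensors, centre them, and return their correlation coefficient.
    g = golden.flatten().float() - golden.float().mean()
    a = actual.flatten().float() - actual.float().mean()
    return float((g @ a) / (g.norm() * a.norm() + 1e-12))

golden = torch.randn(1, 256, 32, 64)
device_out = golden + 0.05 * torch.randn_like(golden)  # small simulated numerical error
assert pearson_cc(golden, device_out) > 0.99  # the threshold these tests now require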

models/experimental/panoptic_deeplab/tests/test_decoder.py

Lines changed: 11 additions & 20 deletions
@@ -13,9 +13,8 @@
     decoder_layer_optimisations,
 )
 from models.experimental.panoptic_deeplab.tt.custom_preprocessing import create_custom_mesh_preprocessor
-from models.experimental.panoptic_deeplab.reference.decoder import (
-    DecoderModel,
-)
+from models.experimental.panoptic_deeplab.reference.decoder import DecoderModel
+from models.experimental.panoptic_deeplab.common import load_torch_model_state
 
 
 class DecoderTestInfra:
@@ -56,28 +55,26 @@ def __init__(
 
         # Create input tensors
         self.torch_input_tensor = torch.randn(
-            (self.batch_size, self.in_channels, self.height, self.width), dtype=torch.float32
+            (self.batch_size, self.in_channels, self.height, self.width), dtype=torch.float
        )
 
         # Create res3 and res2 feature maps with appropriate dimensions
-        self.torch_res3_tensor = torch.randn(
-            (self.batch_size, 512, self.height * 2, self.width * 2), dtype=torch.float32
-        )
+        self.torch_res3_tensor = torch.randn((self.batch_size, 512, self.height * 2, self.width * 2), dtype=torch.float)
 
         self.torch_res2_tensor = torch.randn(
-            (self.batch_size, upsample_channels, self.height * 4, self.width * 4), dtype=torch.float32
+            (self.batch_size, upsample_channels, self.height * 4, self.width * 4), dtype=torch.float
         )
 
         # torch model
-        torch_model = DecoderModel(self.name).eval()
+        torch_model = DecoderModel(self.name)
+        torch_model = load_torch_model_state(torch_model, name)
 
         parameters = preprocess_model_parameters(
             initialize_model=lambda: torch_model,
             custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper),
             device=None,
         )
 
-        parameters.conv_args = {}
         # For ASPP
         aspp_args = infer_ttnn_module_args(
             model=torch_model.aspp, run_model=lambda model: model(self.torch_input_tensor), device=None
@@ -116,12 +113,6 @@ def __init__(
         if hasattr(parameters, "head_2"):
             parameters.head_2.conv_args = head_2_args
 
-        # Convert to bfloat16
-        torch_model.to(torch.bfloat16)
-        self.torch_input_tensor = self.torch_input_tensor.to(torch.bfloat16)
-        self.torch_res3_tensor = self.torch_res3_tensor.to(torch.bfloat16)
-        self.torch_res2_tensor = self.torch_res2_tensor.to(torch.bfloat16)
-
         # Get torch output with all three inputs
         self.torch_output_tensor, self.torch_output_tensor_2 = torch_model(
             self.torch_input_tensor, self.torch_res3_tensor, self.torch_res2_tensor
@@ -189,7 +180,7 @@ def validate(self, output_tensor=None):
 
         batch_size = output_tensor.shape[0]
 
-        valid_pcc = 0.97
+        valid_pcc = 0.99
         self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc)
         assert self.pcc_passed, logger.error(f"PCC check failed: {self.pcc_message}")
         logger.info(
@@ -207,7 +198,7 @@ def validate(self, output_tensor=None):
 
         batch_size = output_tensor.shape[0]
 
-        valid_pcc = 0.96
+        valid_pcc = 0.99
         self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor_2, output_tensor, pcc=valid_pcc)
         assert self.pcc_passed, logger.error(f"PCC check failed: {self.pcc_message}")
         logger.info(
@@ -228,8 +219,8 @@ def validate(self, output_tensor=None):
 @pytest.mark.parametrize(
     "batch_size, in_channels, res3_intermediate_channels, res2_intermediate_channels, out_channels, upsample_channels, height, width, name",
     [
-        (1, 2048, 320, 288, (19,), 256, 32, 64, "semantics_head"),  # semantic head
-        (1, 2048, 320, 160, (2, 1), 256, 32, 64, "instance_head"),  # instance offset head
+        (1, 2048, 320, 288, (19,), 256, 32, 64, "semantic_decoder"),  # semantic head
+        (1, 2048, 320, 160, (2, 1), 256, 32, 64, "instance_decoder"),  # instance offset head
     ],
 )
 def test_decoder(
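The decoder test follows the same pattern as the ASPP test: build the reference decoder under its new name, pull in trained weights through load_torch_model_state, and compute the golden outputs in full precision instead of bfloat16. A condensed sketch of that torch-side setup, with tensor shapes taken from the parametrization above:

import torch

from models.experimental.panoptic_deeplab.common import load_torch_model_state
from models.experimental.panoptic_deeplab.reference.decoder import DecoderModel

name = "instance_decoder"
torch_model = load_torch_model_state(DecoderModel(name), name)

x = torch.randn(1, 2048, 32, 64)         # decoder input: (batch, in_channels, height, width)
res3 = torch.randn(1, 512, 64, 128)      # res3 skip feature map (height * 2, width * 2)
res2 = torch.randn(1, 256, 128, 256)     # res2 skip feature map (upsample_channels = 256)
out, out_2 = torch_model(x, res3, res2)  # out_2 is None for the semantic decoder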
