diff --git a/ai_edge_torch/_convert/test/test_convert.py b/ai_edge_torch/_convert/test/test_convert.py index ebb0d8597..09538d60e 100644 --- a/ai_edge_torch/_convert/test/test_convert.py +++ b/ai_edge_torch/_convert/test/test_convert.py @@ -91,6 +91,78 @@ def forward(self, a, b, c): model_coverage.compare_tflite_torch(edge_model, torch_module, args) ) + def test_convert_conv2d_x1(self): + """Tests conversion of a simple Conv2d module.""" + + class Conv2d(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d( + in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + return self.conv(x) + + args = (torch.randn((1, 3, 224, 224)),) + torch_module = Conv2d().eval() + edge_model = ai_edge_torch.convert(torch_module, args) + + tmp_dir_name = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + tmp_dir_path = os.path.join(tmp_dir_name, "conv2d_x1.tflite") + edge_model.export(tmp_dir_path) + + self.assertTrue( + model_coverage.compare_tflite_torch(edge_model, torch_module, args) + ) + + def test_convert_conv2d_add(self): + """Tests conversion of Conv2d layers with add ops.""" + + class Conv2d_add(nn.Module): + + def __init__(self): + super().__init__() + self.convs = nn.ModuleList() + self.convs.append( + nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ) + ) + for _ in range(14): + self.convs.append( + nn.Conv2d( + in_channels=16, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ) + ) + + def forward(self, x): + x = self.convs[0](x) + for i in range(1, 15): + x = x + self.convs[i](x) + return x + + args = (torch.randn((1, 3, 224, 224)),) + torch_module = Conv2d_add().eval() + edge_model = ai_edge_torch.convert(torch_module, args) + + tmp_dir_name = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + tmp_dir_path = os.path.join(tmp_dir_name, "conv2d_add_x14.tflite") + edge_model.export(tmp_dir_path) + + self.assertTrue( + model_coverage.compare_tflite_torch(edge_model, torch_module, args) + ) + def test_convert_resnet18(self): args = (torch.randn(4, 3, 224, 224),) torch_module = torchvision.models.resnet18().eval() diff --git a/ai_edge_torch/testing/model_coverage/model_coverage.py b/ai_edge_torch/testing/model_coverage/model_coverage.py index 3994007b8..eec4093ce 100644 --- a/ai_edge_torch/testing/model_coverage/model_coverage.py +++ b/ai_edge_torch/testing/model_coverage/model_coverage.py @@ -64,6 +64,34 @@ def _torch_tensors_to_np(*argv): raise ValueError("Unsupported torch.tensor type.") +def _print_diff(tensor_idx, tflite_out, torch_out): + """Prints difference details between two tensors.""" + diff = np.abs(tflite_out - torch_out) + max_abs_diff = np.max(diff) + avg_abs_diff = np.mean(diff) + print(f"Tensor {tensor_idx} difference:") + print(f" PyTorch result: {torch_out}") + print(f" TFLite result: {tflite_out}") + print(f" Difference: {diff}") + print(f" Max absolute difference: {max_abs_diff}") + print(f" Mean absolute difference: {avg_abs_diff}") + top10_diffs = np.sort(diff.flatten())[-10:][::-1] + print(f" Top 10 differences: {top10_diffs}") + nonzero_indices = np.abs(torch_out) > 0 + if np.any(nonzero_indices): + rel_diff = diff[nonzero_indices] / np.abs(torch_out[nonzero_indices]) + max_rel_diff_percent = np.max(rel_diff) * 100 + mean_rel_diff_percent = np.mean(rel_diff) * 100 + print( + " Max relative difference (for non-zero golden values):" + f" {max_rel_diff_percent:.2f}%" + ) + print( + " Mean relative difference (for non-zero golden values):" + f" {mean_rel_diff_percent:.2f}%" + ) + + def compare_tflite_torch( edge_model: model.Model, torch_eval_func: Callable, @@ -72,7 +100,7 @@ def compare_tflite_torch( *, num_valid_inputs: int = 1, signature_name: str = None, - atol: float = 1e-5, + atol: float = 1e-4, rtol: float = 1e-5 ): """Compares torch models and TFLite models. @@ -140,6 +168,16 @@ def get_edge_output(inputs): for out, golden_out in zip(output, golden_output) ]) if not is_equal: + print("TFLite and PyTorch results are different.") + if not is_output_len_eq: + print( + "Output length mismatch:" + f" TFLite {len(output)}, PyTorch {len(golden_output)}" + ) + return False + for i, (out, golden_out) in enumerate(zip(output, golden_output)): + if not np.allclose(out, golden_out, atol=atol, rtol=rtol): + _print_diff(i, out, golden_out) return False return True