| 
52 | 52 | 
 
  | 
53 | 53 | from executorch.backends.cadence.aot.typing_stubs import expand  | 
54 | 54 | from executorch.exir.dialects._ops import ops as exir_ops  | 
55 |  | -from executorch.exir.pass_base import ExportPass  | 
 | 55 | +from executorch.exir.pass_base import ExportPass, ProxyValue  | 
56 | 56 | from executorch.exir.passes import dead_code_elimination_pass  | 
57 | 57 | from torch.fx.passes.infra.pass_base import PassResult  | 
 | 58 | +from torch.utils import _pytree as pytree  | 
58 | 59 | 
 
  | 
59 | 60 | 
 
  | 
60 | 61 | class TestReplaceOpsPasses(unittest.TestCase):  | 
@@ -345,6 +346,194 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(  | 
345 | 346 |             count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x  | 
346 | 347 |         )  | 
347 | 348 | 
 
  | 
 | 349 | +    def assertTensorMetadataIsSame(  | 
 | 350 | +        self, a: Sequence[torch.Tensor], b: Sequence[torch.Tensor]  | 
 | 351 | +    ) -> None:  | 
 | 352 | +        for i, (_a, _b) in enumerate(zip(a, b)):  | 
 | 353 | +            # TODO: actually compare the tensors.  | 
 | 354 | +            self.assertTrue(  | 
 | 355 | +                _a.shape == _b.shape, f"Tensor {i}: {_a.shape} != {_b.shape}"  | 
 | 356 | +            )  | 
 | 357 | +            self.assertTrue(  | 
 | 358 | +                _a.dtype == _b.dtype, f"Tensor {i}: {_a.dtype} != {_b.dtype}"  | 
 | 359 | +            )  | 
 | 360 | + | 
 | 361 | +    @expand(  | 
 | 362 | +        [  | 
 | 363 | +            [(1, 8, 18), 8, 16, 3],  | 
 | 364 | +            [(1, 8, 18), 8, 16, 5, 2],  | 
 | 365 | +            # depthwise + bias  | 
 | 366 | +            [(1, 8, 18), 8, 16, 5, 2, 0, 1, True],  | 
 | 367 | +            # no bias  | 
 | 368 | +            [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],  | 
 | 369 | +            # bias + transposed  | 
 | 370 | +            [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],  | 
 | 371 | +            # Stride of 2 needed.  | 
 | 372 | +            [(1, 8, 3), 8, 8, 48, 2, 23],  | 
 | 373 | +        ]  | 
 | 374 | +    )  | 
 | 375 | +    @torch.no_grad()  | 
 | 376 | +    def test_replace_aten_conv_with_cadence_conv(  | 
 | 377 | +        self,  | 
 | 378 | +        shape: Tuple[int, ...],  | 
 | 379 | +        in_channels: int,  | 
 | 380 | +        out_channels: int,  | 
 | 381 | +        kernel: int,  | 
 | 382 | +        stride: int = 1,  | 
 | 383 | +        padding: int = 0,  | 
 | 384 | +        dilation: int = 1,  | 
 | 385 | +        depthwise: bool = False,  | 
 | 386 | +        bias_enabled: bool = True,  | 
 | 387 | +        output_padding: Optional[int] = None,  | 
 | 388 | +    ) -> None:  | 
 | 389 | +        groups = in_channels if depthwise else 1  | 
 | 390 | +        builder = GraphBuilder()  | 
 | 391 | +        x_tensor = torch.randn(*shape, dtype=torch.float32)  | 
 | 392 | +        x = builder.placeholder("x", x_tensor)  | 
 | 393 | +        weights_tensor = torch.randn(  | 
 | 394 | +            [out_channels, in_channels // groups, kernel], dtype=torch.float32  | 
 | 395 | +        )  | 
 | 396 | +        weights = builder.placeholder("weights", weights_tensor)  | 
 | 397 | +        bias: Optional[ProxyValue] = None  | 
 | 398 | +        bias_tensor: Optional[torch.Tensor] = None  | 
 | 399 | +        if bias_enabled:  | 
 | 400 | +            bias_tensor = torch.randn([out_channels], dtype=torch.float32)  | 
 | 401 | +            bias = builder.placeholder("bias", bias_tensor)  | 
 | 402 | +        convolution = builder.call_operator(  | 
 | 403 | +            op=exir_ops.edge.aten.convolution.default,  | 
 | 404 | +            args=(  | 
 | 405 | +                x,  | 
 | 406 | +                weights,  | 
 | 407 | +                bias,  | 
 | 408 | +                [stride],  | 
 | 409 | +                [padding],  | 
 | 410 | +                [dilation],  | 
 | 411 | +                False,  | 
 | 412 | +                [output_padding] if output_padding else [0],  | 
 | 413 | +                groups,  | 
 | 414 | +            ),  | 
 | 415 | +        )  | 
 | 416 | +        builder.output([convolution])  | 
 | 417 | +        original_gm = builder.get_graph_module()  | 
 | 418 | + | 
 | 419 | +        replacement_pass_result = (  | 
 | 420 | +            ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)  | 
 | 421 | +        )  | 
 | 422 | +        self.assertIsNotNone(replacement_pass_result)  | 
 | 423 | +        graph_after_passes = replacement_pass_result.graph_module  | 
 | 424 | + | 
 | 425 | +        self.assertEqual(  | 
 | 426 | +            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),  | 
 | 427 | +            0,  | 
 | 428 | +        )  | 
 | 429 | +        self.assertEqual(  | 
 | 430 | +            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),  | 
 | 431 | +            1,  | 
 | 432 | +        )  | 
 | 433 | +        self.assertEqual(  | 
 | 434 | +            count_node(  | 
 | 435 | +                graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default  | 
 | 436 | +            ),  | 
 | 437 | +            0,  | 
 | 438 | +        )  | 
 | 439 | + | 
 | 440 | +        inputs = (x.to_tensor(), weights.to_tensor())  | 
 | 441 | +        if bias is not None:  | 
 | 442 | +            inputs += (bias.to_tensor(),)  | 
 | 443 | +        self.assertTensorMetadataIsSame(  | 
 | 444 | +            pytree.tree_flatten(original_gm.forward(*inputs))[0],  | 
 | 445 | +            pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],  | 
 | 446 | +        )  | 
 | 447 | + | 
 | 448 | +    @expand(  | 
 | 449 | +        [  | 
 | 450 | +            [(1, 8, 18), 8, 16, 3],  | 
 | 451 | +            [(1, 8, 18), 8, 16, 5, 2],  | 
 | 452 | +            # depthwise + bias  | 
 | 453 | +            [(1, 8, 18), 8, 16, 5, 2, 0, 1, True, True],  | 
 | 454 | +            # no bias  | 
 | 455 | +            [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],  | 
 | 456 | +            # depthwise + no bias  | 
 | 457 | +            [(1, 8, 18), 8, 16, 3, 1, 0, 1, True, False],  | 
 | 458 | +            # bias  | 
 | 459 | +            [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],  | 
 | 460 | +        ]  | 
 | 461 | +    )  | 
 | 462 | +    @torch.no_grad()  | 
 | 463 | +    def test_replace_aten_transposed_conv_with_cadence_transposed_conv(  | 
 | 464 | +        self,  | 
 | 465 | +        shape: Tuple[int, ...],  | 
 | 466 | +        in_channels: int,  | 
 | 467 | +        out_channels: int,  | 
 | 468 | +        kernel: int,  | 
 | 469 | +        stride: int = 1,  | 
 | 470 | +        padding: int = 0,  | 
 | 471 | +        dilation: int = 1,  | 
 | 472 | +        depthwise: bool = False,  | 
 | 473 | +        bias_enabled: bool = True,  | 
 | 474 | +        output_padding: Optional[int] = None,  | 
 | 475 | +    ) -> None:  | 
 | 476 | +        groups = in_channels if depthwise else 1  | 
 | 477 | +        builder = GraphBuilder()  | 
 | 478 | +        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))  | 
 | 479 | +        weights_shape = [in_channels, out_channels // groups, kernel]  | 
 | 480 | +        weights = builder.placeholder(  | 
 | 481 | +            "weights",  | 
 | 482 | +            torch.randn(weights_shape, dtype=torch.float32),  | 
 | 483 | +        )  | 
 | 484 | +        bias = (  | 
 | 485 | +            builder.placeholder(  | 
 | 486 | +                "bias", torch.randn([out_channels], dtype=torch.float32)  | 
 | 487 | +            )  | 
 | 488 | +            if bias_enabled  | 
 | 489 | +            else None  | 
 | 490 | +        )  | 
 | 491 | +        convolution = builder.call_operator(  | 
 | 492 | +            op=exir_ops.edge.aten.convolution.default,  | 
 | 493 | +            args=(  | 
 | 494 | +                x,  | 
 | 495 | +                weights,  | 
 | 496 | +                bias,  | 
 | 497 | +                [stride],  | 
 | 498 | +                [padding],  | 
 | 499 | +                [dilation],  | 
 | 500 | +                True,  | 
 | 501 | +                [output_padding] if output_padding else [0],  | 
 | 502 | +                groups,  | 
 | 503 | +            ),  | 
 | 504 | +        )  | 
 | 505 | +        builder.output([convolution])  | 
 | 506 | +        original_gm = builder.get_graph_module()  | 
 | 507 | + | 
 | 508 | +        replacement_pass_result = (  | 
 | 509 | +            ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)  | 
 | 510 | +        )  | 
 | 511 | +        self.assertIsNotNone(replacement_pass_result)  | 
 | 512 | +        graph_after_passes = replacement_pass_result.graph_module  | 
 | 513 | + | 
 | 514 | +        self.assertEqual(  | 
 | 515 | +            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),  | 
 | 516 | +            0,  | 
 | 517 | +        )  | 
 | 518 | +        self.assertEqual(  | 
 | 519 | +            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),  | 
 | 520 | +            0,  | 
 | 521 | +        )  | 
 | 522 | +        self.assertEqual(  | 
 | 523 | +            count_node(  | 
 | 524 | +                graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default  | 
 | 525 | +            ),  | 
 | 526 | +            1,  | 
 | 527 | +        )  | 
 | 528 | + | 
 | 529 | +        inputs = (x.to_tensor(), weights.to_tensor())  | 
 | 530 | +        if bias is not None:  | 
 | 531 | +            inputs += (bias.to_tensor(),)  | 
 | 532 | +        self.assertTensorMetadataIsSame(  | 
 | 533 | +            pytree.tree_flatten(original_gm.forward(*inputs))[0],  | 
 | 534 | +            pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],  | 
 | 535 | +        )  | 
 | 536 | + | 
348 | 537 |     @expand(  | 
349 | 538 |         [  | 
350 | 539 |             [(1, 8, 33), 8, 16, 3],  | 
 | 
0 commit comments