|
19 | 19 | from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( |
20 | 20 | PartiallyMaterializedTensor, |
21 | 21 | ) |
22 | | -from hypothesis import given, settings, strategies as st, Verbosity |
| 22 | +from hypothesis import assume, given, settings, strategies as st, Verbosity |
23 | 23 | from torch import distributed as dist |
24 | 24 | from torch.distributed._shard.sharded_tensor import ShardedTensor |
25 | 25 | from torch.distributed._tensor import DTensor |
@@ -624,11 +624,10 @@ def test_load_state_dict( |
624 | 624 | kernel_type: str, |
625 | 625 | is_training: bool, |
626 | 626 | ) -> None: |
627 | | - if ( |
628 | | - self.device == torch.device("cpu") |
629 | | - and kernel_type != EmbeddingComputeKernel.FUSED.value |
630 | | - ): |
631 | | - self.skipTest("CPU does not support uvm.") |
| 627 | + assume( |
| 628 | + self.device != torch.device("cpu") |
| 629 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 630 | + ) |
632 | 631 |
|
633 | 632 | sharders = [ |
634 | 633 | cast( |
@@ -683,11 +682,10 @@ def test_optimizer_load_state_dict( |
683 | 682 | sharding_type: str, |
684 | 683 | kernel_type: str, |
685 | 684 | ) -> None: |
686 | | - if ( |
687 | | - self.device == torch.device("cpu") |
688 | | - and kernel_type != EmbeddingComputeKernel.FUSED.value |
689 | | - ): |
690 | | - self.skipTest("CPU does not support uvm.") |
| 685 | + assume( |
| 686 | + self.device != torch.device("cpu") |
| 687 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 688 | + ) |
691 | 689 |
|
692 | 690 | sharders = [ |
693 | 691 | cast( |
@@ -800,11 +798,10 @@ def test_load_state_dict_dp( |
800 | 798 | def test_load_state_dict_prefix( |
801 | 799 | self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool |
802 | 800 | ) -> None: |
803 | | - if ( |
804 | | - self.device == torch.device("cpu") |
805 | | - and kernel_type != EmbeddingComputeKernel.FUSED.value |
806 | | - ): |
807 | | - self.skipTest("CPU does not support uvm.") |
| 801 | + assume( |
| 802 | + self.device != torch.device("cpu") |
| 803 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 804 | + ) |
808 | 805 |
|
809 | 806 | sharders = [ |
810 | 807 | cast( |
@@ -855,11 +852,10 @@ def test_load_state_dict_prefix( |
855 | 852 | def test_params_and_buffers( |
856 | 853 | self, sharder_type: str, sharding_type: str, kernel_type: str |
857 | 854 | ) -> None: |
858 | | - if ( |
859 | | - self.device == torch.device("cpu") |
860 | | - and kernel_type != EmbeddingComputeKernel.FUSED.value |
861 | | - ): |
862 | | - self.skipTest("CPU does not support uvm.") |
| 855 | + assume( |
| 856 | + self.device != torch.device("cpu") |
| 857 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 858 | + ) |
863 | 859 |
|
864 | 860 | sharders = [ |
865 | 861 | create_test_sharder(sharder_type, sharding_type, kernel_type), |
@@ -897,11 +893,10 @@ def test_params_and_buffers( |
897 | 893 | def test_load_state_dict_cw_multiple_shards( |
898 | 894 | self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool |
899 | 895 | ) -> None: |
900 | | - if ( |
901 | | - self.device == torch.device("cpu") |
902 | | - and kernel_type != EmbeddingComputeKernel.FUSED.value |
903 | | - ): |
904 | | - self.skipTest("CPU does not support uvm.") |
| 896 | + assume( |
| 897 | + self.device != torch.device("cpu") |
| 898 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 899 | + ) |
905 | 900 |
|
906 | 901 | sharders = [ |
907 | 902 | cast( |
|
0 commit comments