@@ -467,11 +467,10 @@ def test_sharding_cw(
467467 data_type : DataType ,
468468 allow_zero_batch_size : bool ,
469469 ) -> None :
470- if (
471- self .device == torch .device ("cpu" )
472- and kernel_type != EmbeddingComputeKernel .FUSED .value
473- ):
474- self .skipTest ("CPU does not support uvm." )
470+ assume (
471+ self .device != torch .device ("cpu" )
472+ or kernel_type == EmbeddingComputeKernel .FUSED .value
473+ )
475474
476475 sharding_type = ShardingType .COLUMN_WISE .value
477476 assume (
@@ -548,11 +547,10 @@ def test_sharding_twcw(
548547 variable_batch_size : bool ,
549548 data_type : DataType ,
550549 ) -> None :
551- if (
552- self .device == torch .device ("cpu" )
553- and kernel_type != EmbeddingComputeKernel .FUSED .value
554- ):
555- self .skipTest ("CPU does not support uvm." )
550+ assume (
551+ self .device != torch .device ("cpu" )
552+ or kernel_type == EmbeddingComputeKernel .FUSED .value
553+ )
556554
557555 sharding_type = ShardingType .TABLE_COLUMN_WISE .value
558556 assume (
@@ -629,11 +627,10 @@ def test_sharding_tw(
629627 variable_batch_size : bool ,
630628 data_type : DataType ,
631629 ) -> None :
632- if (
633- self .device == torch .device ("cpu" )
634- and kernel_type != EmbeddingComputeKernel .FUSED .value
635- ):
636- self .skipTest ("CPU does not support uvm." )
630+ assume (
631+ self .device != torch .device ("cpu" )
632+ or kernel_type == EmbeddingComputeKernel .FUSED .value
633+ )
637634
638635 sharding_type = ShardingType .TABLE_WISE .value
639636 assume (
0 commit comments