@@ -326,56 +326,72 @@ def call_operator(self, op, args, kwargs, meta):
326326 self .assertTrue (is_contiguous_dim_order (expected ))
327327
328328 def test_op_clone_replacement_channels_last_survives (self ):
329- _clone_dim_order_op_str = (
330- "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
331- )
329+ clone_op_cases = [
330+ # Case testing aten.clone by setting _skip_dim_order to True
331+ (True , "executorch_exir_dialects_edge__ops_aten_clone_default" ),
332+ # Case testing _clone_dim_order by setting _skip_dim_order to False
333+ (
334+ False ,
335+ "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" ,
336+ ),
337+ ]
332338
333- model = SimpleCloneChannelsLastModule ()
334- x = torch .randn (3 , 4 , 5 , 6 ).to (memory_format = torch .contiguous_format )
339+ for skip_dim_order , clone_op_str in clone_op_cases :
340+ model = SimpleCloneChannelsLastModule ()
341+ x = torch .randn (3 , 4 , 5 , 6 ).to (memory_format = torch .contiguous_format )
335342
336- exported = export (model .eval (), (x ,), strict = True )
337- before_epm = to_edge (
338- exported , compile_config = EdgeCompileConfig (_skip_dim_order = False )
339- )
343+ exported = export (model .eval (), (x ,), strict = True )
344+ before_epm = to_edge (
345+ exported ,
346+ compile_config = EdgeCompileConfig (_skip_dim_order = skip_dim_order ),
347+ )
340348
341- updated_epm = before_epm .transform ([RemoveCloneOpsTransform ()])
349+ updated_epm = before_epm .transform ([RemoveCloneOpsTransform ()])
342350
343- FileCheck ().check_count (_clone_dim_order_op_str , 1 , exactly = True ).run (
344- updated_epm .exported_program ().graph_module .code
345- )
351+ FileCheck ().check_count (clone_op_str , 1 , exactly = True ).run (
352+ updated_epm .exported_program ().graph_module .code
353+ )
346354
347- expected = before_epm .exported_program ().module ()(x )
348- actual = updated_epm .exported_program ().module ()(x )
349- assert torch .allclose (actual , expected )
350- assert is_channel_last_dim_order (actual )
355+ expected = before_epm .exported_program ().module ()(x )
356+ actual = updated_epm .exported_program ().module ()(x )
357+ assert torch .allclose (actual , expected )
358+ assert is_channel_last_dim_order (actual )
351359
352360 def test_op_clone_without_transformation_removed (self ):
353- _clone_dim_order_op_str = (
354- "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
355- )
361+ clone_op_cases = [
362+ # Case testing aten.clone by setting _skip_dim_order to True
363+ (True , "executorch_exir_dialects_edge__ops_aten_clone_default" ),
364+ # Case testing _clone_dim_order by setting _skip_dim_order to False
365+ (
366+ False ,
367+ "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" ,
368+ ),
369+ ]
356370
357- model = SimpleCloneChannelsLastModule ()
358- x = torch .randn (3 , 4 , 5 , 6 ).to (memory_format = torch .channels_last )
371+ for skip_dim_order , clone_op_str in clone_op_cases :
372+ model = SimpleCloneChannelsLastModule ()
373+ x = torch .randn (3 , 4 , 5 , 6 ).to (memory_format = torch .channels_last )
359374
360- exported = export (model .eval (), (x ,), strict = True )
361- before_epm = to_edge (
362- exported , compile_config = EdgeCompileConfig (_skip_dim_order = False )
363- )
375+ exported = export (model .eval (), (x ,), strict = True )
376+ before_epm = to_edge (
377+ exported ,
378+ compile_config = EdgeCompileConfig (_skip_dim_order = skip_dim_order ),
379+ )
364380
365- FileCheck ().check_count (_clone_dim_order_op_str , 1 , exactly = True ).run (
366- before_epm .exported_program ().graph_module .code
367- )
381+ FileCheck ().check_count (clone_op_str , 1 , exactly = True ).run (
382+ before_epm .exported_program ().graph_module .code
383+ )
368384
369- updated_epm = before_epm .transform ([RemoveCloneOpsTransform ()])
385+ updated_epm = before_epm .transform ([RemoveCloneOpsTransform ()])
370386
371- FileCheck ().check_not (_clone_dim_order_op_str ).run (
372- updated_epm .exported_program ().graph_module .code
373- )
387+ FileCheck ().check_not (clone_op_str ).run (
388+ updated_epm .exported_program ().graph_module .code
389+ )
374390
375- expected = before_epm .exported_program ().module ()(x )
376- actual = updated_epm .exported_program ().module ()(x )
377- assert torch .allclose (actual , expected )
378- assert is_channel_last_dim_order (actual )
391+ expected = before_epm .exported_program ().module ()(x )
392+ actual = updated_epm .exported_program ().module ()(x )
393+ assert torch .allclose (actual , expected )
394+ assert is_channel_last_dim_order (actual )
379395
380396 def test_resnet18 (self ) -> None :
381397 model = torchvision .models .resnet18 ()
0 commit comments