3636backends = testing_reqs .backends
3737compute_units = testing_reqs .compute_units
3838
39- TORCH_EXPORT_DEFAULT_LOWER_BOUND = {TorchFrontend .TORCHEXPORT : 2 , TorchFrontend .EXECUTORCH : 2 }
40- if torch .__version__ >= "2.4.0" :
41- TORCH_EXPORT_DEFAULT_LOWER_BOUND [TorchFrontend .TORCHEXPORT ] = 0
42-
4339
4440class TestTorchExportConversionAPI (TorchBaseTest ):
4541 @pytest .mark .parametrize ("frontend" , frontends )
@@ -160,7 +156,10 @@ def forward(self, x):
160156 )[1 ]
161157 input_proto = coreml_model .input_description ._fd_spec [0 ]
162158 size_ranges = input_proto .type .multiArrayType .shapeRange .sizeRanges
163- assert size_ranges [0 ].lowerBound == TORCH_EXPORT_DEFAULT_LOWER_BOUND [frontend ]
159+ assert size_ranges [0 ].lowerBound == {
160+ TorchFrontend .TORCHEXPORT : 0 ,
161+ TorchFrontend .EXECUTORCH : 2 ,
162+ }[frontend ]
164163 assert size_ranges [0 ].upperBound == 2147483647
165164 assert size_ranges [1 ].lowerBound == 3
166165 assert size_ranges [1 ].upperBound == 3
@@ -351,7 +350,7 @@ def forward(self, input, other):
351350
352351 dynamic_shapes = None
353352 if dynamic :
354- dim0 = torch .export .Dim ( "dim0" )
353+ dim0 = torch .export .Dim . AUTO
355354 dim1 = torch .export .Dim ("dim1" , min = 1 , max = 3 )
356355 dynamic_shapes = {
357356 "input" : {0 : dim0 , 1 : dim1 },
@@ -370,11 +369,12 @@ def forward(self, input, other):
370369 if dynamic :
371370 for input_proto in coreml_model .input_description ._fd_spec :
372371 size_ranges = input_proto .type .multiArrayType .shapeRange .sizeRanges
373- assert size_ranges [0 ].lowerBound == TORCH_EXPORT_DEFAULT_LOWER_BOUND [ frontend ]
372+ assert size_ranges [0 ].lowerBound == 2
374373 assert size_ranges [0 ].upperBound == 2147483647
375- assert size_ranges [1 ].lowerBound == max (
376- 1 , TORCH_EXPORT_DEFAULT_LOWER_BOUND [frontend ]
377- )
374+ assert size_ranges [1 ].lowerBound == {
375+ TorchFrontend .TORCHEXPORT : 1 ,
376+ TorchFrontend .EXECUTORCH : 2 ,
377+ }[frontend ]
378378 assert size_ranges [1 ].upperBound == 3
379379
380380 mil_program = coreml_model ._mil_program
@@ -451,7 +451,10 @@ def forward(self, arg):
451451 if dynamic :
452452 input_proto = coreml_model .input_description ._fd_spec [0 ]
453453 size_ranges = input_proto .type .multiArrayType .shapeRange .sizeRanges
454- assert size_ranges [0 ].lowerBound == TORCH_EXPORT_DEFAULT_LOWER_BOUND [frontend ]
454+ assert size_ranges [0 ].lowerBound == {
455+ TorchFrontend .TORCHEXPORT : 0 ,
456+ TorchFrontend .EXECUTORCH : 2 ,
457+ }[frontend ]
455458 assert size_ranges [0 ].upperBound == 2147483647
456459 assert size_ranges [1 ].lowerBound == 3
457460 assert size_ranges [1 ].upperBound == 3
@@ -637,10 +640,16 @@ def forward(self, a, x, b):
637640 if i == 0 :
638641 assert size_ranges [0 ].lowerBound == 2
639642 assert size_ranges [0 ].upperBound == 2
640- assert size_ranges [1 ].lowerBound == TORCH_EXPORT_DEFAULT_LOWER_BOUND [frontend ]
643+ assert size_ranges [1 ].lowerBound == {
644+ TorchFrontend .TORCHEXPORT : 0 ,
645+ TorchFrontend .EXECUTORCH : 2 ,
646+ }[frontend ]
641647 assert size_ranges [1 ].upperBound == 2147483647
642648 elif i == 1 :
643- assert size_ranges [0 ].lowerBound == TORCH_EXPORT_DEFAULT_LOWER_BOUND [frontend ]
649+ assert size_ranges [0 ].lowerBound == {
650+ TorchFrontend .TORCHEXPORT : 0 ,
651+ TorchFrontend .EXECUTORCH : 2 ,
652+ }[frontend ]
644653 assert size_ranges [0 ].upperBound == 2147483647
645654 assert size_ranges [1 ].lowerBound == 2
646655 assert size_ranges [1 ].upperBound == 2
@@ -745,7 +754,10 @@ def forward(self, x):
745754 if dynamic :
746755 input_proto = coreml_model .input_description ._fd_spec [0 ]
747756 size_ranges = input_proto .type .multiArrayType .shapeRange .sizeRanges
748- assert size_ranges [0 ].lowerBound == TORCH_EXPORT_DEFAULT_LOWER_BOUND [frontend ]
757+ assert size_ranges [0 ].lowerBound == {
758+ TorchFrontend .TORCHEXPORT : 0 ,
759+ TorchFrontend .EXECUTORCH : 2 ,
760+ }[frontend ]
749761 assert size_ranges [0 ].upperBound == 2147483647
750762 assert size_ranges [1 ].lowerBound == 2
751763 assert size_ranges [1 ].upperBound == 2
0 commit comments