@@ -306,9 +306,14 @@ def __init__(
306306 rtol : float = 1e-03 ,
307307 qtol : int = 1 ,
308308 dynamic_shapes : Optional [Tuple [Any ]] = None ,
309+ tosa_extensions : Optional [List [str ]] = None ,
309310 ):
311+ if tosa_extensions is None :
312+ tosa_extensions = []
310313 tosa_profiles = {
311- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
314+ "1.0" : TosaSpecification .create_from_string (
315+ "TOSA-1.0+INT" + "" .join ([f"+{ ext } " for ext in tosa_extensions ])
316+ ),
312317 }
313318 tosa_version = conftest .get_option ("tosa_version" )
314319
@@ -406,9 +411,14 @@ def __init__(
406411 transform_passes : Optional [
407412 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
408413 ] = None ,
414+ tosa_extensions : Optional [List [str ]] = None ,
409415 ):
416+ if tosa_extensions is None :
417+ tosa_extensions = []
410418 tosa_profiles = {
411- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
419+ "1.0" : TosaSpecification .create_from_string (
420+ "TOSA-1.0+FP" + "" .join ([f"+{ ext } " for ext in tosa_extensions ])
421+ ),
412422 }
413423 tosa_version = conftest .get_option ("tosa_version" )
414424
@@ -655,10 +665,15 @@ def __init__(
655665 pass_functions : Optional [List [Callable ]] = None ,
656666 passes_with_exported_program : Optional [List [Type [ExportPass ]]] = None ,
657667 custom_path : str = None ,
668+ tosa_extensions : Optional [List [str ]] = None ,
658669 ):
670+ if tosa_extensions is None :
671+ tosa_extensions = []
659672 tosa_profiles = {
660673 "1.0" : TosaSpecification .create_from_string (
661- "TOSA-1.0+" + ("INT" if quantize else "FP" )
674+ "TOSA-1.0+"
675+ + ("INT" if quantize else "FP" )
676+ + "" .join ([f"+{ ext } " for ext in tosa_extensions ]),
662677 ),
663678 }
664679 tosa_version = conftest .get_option ("tosa_version" )
@@ -721,9 +736,14 @@ def __init__(
721736 module : torch .nn .Module ,
722737 test_data : T ,
723738 custom_path : str = None ,
739+ tosa_extensions : Optional [List [str ]] = None ,
724740 ):
741+ if tosa_extensions is None :
742+ tosa_extensions = []
725743 tosa_profiles = {
726- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
744+ "1.0" : TosaSpecification .create_from_string (
745+ "TOSA-1.0+INT" + "" .join ([f"+{ ext } " for ext in tosa_extensions ]),
746+ ),
727747 }
728748 tosa_version = conftest .get_option ("tosa_version" )
729749
@@ -779,18 +799,23 @@ def __init__(
779799 custom_path : str = None ,
780800 quantize : Optional [bool ] = False ,
781801 u55_subset : Optional [bool ] = False ,
802+ tosa_extensions : Optional [List [str ]] = None ,
782803 ):
804+ if tosa_extensions is None :
805+ tosa_extensions = []
783806 tosa_profiles = {
784- "1.0" : "TOSA-1.0+" + ("INT" if quantize else "FP" ),
807+ "1.0" : TosaSpecification .create_from_string (
808+ "TOSA-1.0+"
809+ + ("INT" if quantize else "FP" )
810+ + ("+u55" if u55_subset and quantize else "" )
811+ + "" .join ([f"+{ ext } " for ext in tosa_extensions ]),
812+ ),
785813 }
786- tosa_version = tosa_profiles [ conftest .get_option ("tosa_version" )]
814+ tosa_version = conftest .get_option ("tosa_version" )
787815
788- if u55_subset and quantize :
789- tosa_version = f"{ tosa_version } +u55"
816+ tosa_spec = tosa_profiles [tosa_version ]
790817
791- compile_spec = common .get_tosa_compile_spec (
792- tosa_version , custom_path = custom_path
793- )
818+ compile_spec = common .get_tosa_compile_spec (tosa_spec , custom_path = custom_path )
794819 super ().__init__ (
795820 module ,
796821 test_data ,
@@ -799,7 +824,7 @@ def __init__(
799824 [],
800825 )
801826
802- if "INT" in tosa_version :
827+ if tosa_spec . support_integer () :
803828 self .add_stage (self .tester .quantize , pos = 0 )
804829
805830 self .change_args ("check_not.exir" , [])
@@ -855,11 +880,16 @@ def __init__(
855880 transform_passes : Optional [
856881 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
857882 ] = None ,
883+ tosa_extensions : Optional [List [str ]] = None ,
858884 ):
859885
860- tosa_profile = TosaSpecification .create_from_string (tosa_version )
886+ if tosa_extensions is None :
887+ tosa_extensions = []
888+ tosa_spec = TosaSpecification .create_from_string (
889+ tosa_version + "" .join ([f"+{ ext } " for ext in tosa_extensions ])
890+ )
861891 compile_spec = common .get_vgf_compile_spec (
862- tosa_profile , compiler_flags = vgf_compiler_flags , custom_path = custom_path
892+ tosa_spec , compiler_flags = vgf_compiler_flags , custom_path = custom_path
863893 )
864894
865895 super ().__init__ (
@@ -873,7 +903,7 @@ def __init__(
873903 transform_passes = transform_passes ,
874904 )
875905
876- if "INT" in tosa_version :
906+ if tosa_spec . support_integer () :
877907 quantizer = VgfQuantizer (compile_spec )
878908 quantization_config = get_symmetric_quantization_config (
879909 is_per_channel = per_channel_quantization
0 commit comments