@@ -306,9 +306,14 @@ def __init__(
306
306
rtol : float = 1e-03 ,
307
307
qtol : int = 1 ,
308
308
dynamic_shapes : Optional [Tuple [Any ]] = None ,
309
+ tosa_extensions : Optional [List [str ]] = None ,
309
310
):
311
+ if tosa_extensions is None :
312
+ tosa_extensions = []
310
313
tosa_profiles = {
311
- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
314
+ "1.0" : TosaSpecification .create_from_string (
315
+ "TOSA-1.0+INT" + "" .join ([f"+{ ext } " for ext in tosa_extensions ])
316
+ ),
312
317
}
313
318
tosa_version = conftest .get_option ("tosa_version" )
314
319
@@ -406,9 +411,14 @@ def __init__(
406
411
transform_passes : Optional [
407
412
Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
408
413
] = None ,
414
+ tosa_extensions : Optional [List [str ]] = None ,
409
415
):
416
+ if tosa_extensions is None :
417
+ tosa_extensions = []
410
418
tosa_profiles = {
411
- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
419
+ "1.0" : TosaSpecification .create_from_string (
420
+ "TOSA-1.0+FP" + "" .join ([f"+{ ext } " for ext in tosa_extensions ])
421
+ ),
412
422
}
413
423
tosa_version = conftest .get_option ("tosa_version" )
414
424
@@ -655,10 +665,15 @@ def __init__(
655
665
pass_functions : Optional [List [Callable ]] = None ,
656
666
passes_with_exported_program : Optional [List [Type [ExportPass ]]] = None ,
657
667
custom_path : str = None ,
668
+ tosa_extensions : Optional [List [str ]] = None ,
658
669
):
670
+ if tosa_extensions is None :
671
+ tosa_extensions = []
659
672
tosa_profiles = {
660
673
"1.0" : TosaSpecification .create_from_string (
661
- "TOSA-1.0+" + ("INT" if quantize else "FP" )
674
+ "TOSA-1.0+"
675
+ + ("INT" if quantize else "FP" )
676
+ + "" .join ([f"+{ ext } " for ext in tosa_extensions ]),
662
677
),
663
678
}
664
679
tosa_version = conftest .get_option ("tosa_version" )
@@ -721,9 +736,14 @@ def __init__(
721
736
module : torch .nn .Module ,
722
737
test_data : T ,
723
738
custom_path : str = None ,
739
+ tosa_extensions : Optional [List [str ]] = None ,
724
740
):
741
+ if tosa_extensions is None :
742
+ tosa_extensions = []
725
743
tosa_profiles = {
726
- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
744
+ "1.0" : TosaSpecification .create_from_string (
745
+ "TOSA-1.0+INT" + "" .join ([f"+{ ext } " for ext in tosa_extensions ]),
746
+ ),
727
747
}
728
748
tosa_version = conftest .get_option ("tosa_version" )
729
749
@@ -779,18 +799,23 @@ def __init__(
779
799
custom_path : str = None ,
780
800
quantize : Optional [bool ] = False ,
781
801
u55_subset : Optional [bool ] = False ,
802
+ tosa_extensions : Optional [List [str ]] = None ,
782
803
):
804
+ if tosa_extensions is None :
805
+ tosa_extensions = []
783
806
tosa_profiles = {
784
- "1.0" : "TOSA-1.0+" + ("INT" if quantize else "FP" ),
807
+ "1.0" : TosaSpecification .create_from_string (
808
+ "TOSA-1.0+"
809
+ + ("INT" if quantize else "FP" )
810
+ + ("+u55" if u55_subset and quantize else "" )
811
+ + "" .join ([f"+{ ext } " for ext in tosa_extensions ]),
812
+ ),
785
813
}
786
- tosa_version = tosa_profiles [ conftest .get_option ("tosa_version" )]
814
+ tosa_version = conftest .get_option ("tosa_version" )
787
815
788
- if u55_subset and quantize :
789
- tosa_version = f"{ tosa_version } +u55"
816
+ tosa_spec = tosa_profiles [tosa_version ]
790
817
791
- compile_spec = common .get_tosa_compile_spec (
792
- tosa_version , custom_path = custom_path
793
- )
818
+ compile_spec = common .get_tosa_compile_spec (tosa_spec , custom_path = custom_path )
794
819
super ().__init__ (
795
820
module ,
796
821
test_data ,
@@ -799,7 +824,7 @@ def __init__(
799
824
[],
800
825
)
801
826
802
- if "INT" in tosa_version :
827
+ if tosa_spec . support_integer () :
803
828
self .add_stage (self .tester .quantize , pos = 0 )
804
829
805
830
self .change_args ("check_not.exir" , [])
@@ -855,11 +880,16 @@ def __init__(
855
880
transform_passes : Optional [
856
881
Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
857
882
] = None ,
883
+ tosa_extensions : Optional [List [str ]] = None ,
858
884
):
859
885
860
- tosa_profile = TosaSpecification .create_from_string (tosa_version )
886
+ if tosa_extensions is None :
887
+ tosa_extensions = []
888
+ tosa_spec = TosaSpecification .create_from_string (
889
+ tosa_version + "" .join ([f"+{ ext } " for ext in tosa_extensions ])
890
+ )
861
891
compile_spec = common .get_vgf_compile_spec (
862
- tosa_profile , compiler_flags = vgf_compiler_flags , custom_path = custom_path
892
+ tosa_spec , compiler_flags = vgf_compiler_flags , custom_path = custom_path
863
893
)
864
894
865
895
super ().__init__ (
@@ -873,7 +903,7 @@ def __init__(
873
903
transform_passes = transform_passes ,
874
904
)
875
905
876
- if "INT" in tosa_version :
906
+ if tosa_spec . support_integer () :
877
907
quantizer = VgfQuantizer (compile_spec )
878
908
quantization_config = get_symmetric_quantization_config (
879
909
is_per_channel = per_channel_quantization
0 commit comments