@@ -3079,3 +3079,135 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
30793079 self .assertIn (
30803080 "Hidden dimension must be a multiple of 4" , str (context .exception )
30813081 )
3082+
@expand(
    [
        # Field order for every case:
        # (name, input_tensor, mask, dim, in_scale, in_zero_point,
        #  out_scale, out_zero_point, dtype, expected_output)
        (
            "basic_int8_dim_1",
            torch.tensor([[10, 20, 30]], dtype=torch.int8),
            None,
            1,
            0.1,
            0,
            0.004,
            0,
            torch.int8,
            torch.tensor([[23, 61, 127]], dtype=torch.int8),
        ),
        (
            "uint8_with_zero_points",
            torch.tensor([[128, 130, 132]], dtype=torch.uint8),
            None,
            1,
            0.1,
            128,
            0.004,
            128,
            torch.uint8,
            torch.tensor([[195, 210, 228]], dtype=torch.uint8),
        ),
        (
            "basic_int16",
            torch.tensor([[100, 200, 300]], dtype=torch.int16),
            None,
            1,
            0.01,
            0,
            0.004,
            0,
            torch.int16,
            torch.tensor([[23, 61, 166]], dtype=torch.int16),
        ),
        (
            "multi_row_int8",
            torch.tensor([[10, 20, 30], [5, 10, 15]], dtype=torch.int8),
            None,
            1,
            0.1,
            0,
            0.004,
            0,
            torch.int8,
            torch.tensor([[23, 61, 127], [47, 77, 127]], dtype=torch.int8),
        ),
        (
            "softmax_dim_0",
            torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
            None,
            0,
            0.1,
            0,
            0.004,
            0,
            torch.int8,
            torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
        ),
    ]
)
def test_quantized_softmax_per_tensor(
    self,
    name: str,
    input_tensor: torch.Tensor,
    mask: torch.Tensor | None,
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    dtype: torch.dtype,
    expected_output: torch.Tensor,
) -> None:
    """Check cadence.quantized_softmax.per_tensor against one parameterized case.

    The op is called with the case's input, mask, dim and quantization
    parameters; the result must have dtype ``dtype``, the same shape as
    ``input_tensor``, and values close to ``expected_output``.
    """
    softmax_per_tensor = torch.ops.cadence.quantized_softmax.per_tensor
    output = softmax_per_tensor(
        input_tensor,
        mask,
        dim,
        in_scale,
        in_zero_point,
        out_scale,
        out_zero_point,
    )

    # Structural properties: result dtype and shape.
    dtype_message = f"Output dtype should be { dtype } in { name } "
    self.assertEqual(output.dtype, dtype, dtype_message)
    shape_message = f"Output shape should match input shape in { name } "
    self.assertEqual(output.shape, input_tensor.shape, shape_message)

    # Value check on the raw quantized integers (cast to float32 so that
    # allclose can compare them). The tolerance is 5 quantized units
    # absolute plus 5% relative, to absorb small quantization errors.
    actual_values = output.to(torch.float32)
    wanted_values = expected_output.to(torch.float32)
    values_match = torch.allclose(
        actual_values,
        wanted_values,
        rtol=0.05,
        atol=5.0,
    )
    self.assertTrue(
        values_match,
        f"Output values don't match expected in { name } . Got { output } , expected { expected_output } ",
    )
3191+
def test_quantized_softmax(self) -> None:
    """Check the default cadence.quantized_softmax variant.

    This variant takes the input scale and zero point as tensors. The test
    verifies the output dtype and shape and, additionally, the quantized
    values, so that a numerically wrong result cannot pass unnoticed.
    """
    input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8)
    in_scale = torch.tensor([0.1])
    in_zero_point = torch.tensor([0])
    out_scale = 0.004
    out_zero_point = 0
    output = torch.ops.cadence.quantized_softmax(
        input_tensor,
        None,  # mask
        1,  # dim
        in_scale,
        in_zero_point,
        out_scale,
        out_zero_point,
    )

    # Verify output properties
    self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8")
    self.assertEqual(
        output.shape,
        input_tensor.shape,
        "Output shape should match input shape",
    )

    # Verify output values. The dequantized input is [1.0, 2.0, 3.0], whose
    # softmax is about [0.090, 0.245, 0.665]; divided by out_scale that is
    # about [22.5, 61.2, 166.3], and the last entry saturates at the int8
    # maximum of 127. These are the same expected values the per-tensor test
    # uses for identical input and scales ("basic_int8_dim_1").
    # NOTE(review): assumes the default variant shares its numerics with the
    # per_tensor variant -- confirm against the op implementation.
    expected_output = torch.tensor([[23, 61, 127]], dtype=torch.int8)
    self.assertTrue(
        torch.allclose(
            output.to(torch.float32),
            expected_output.to(torch.float32),
            rtol=0.05,
            atol=5.0,
        ),
        f"Output values don't match expected. Got {output}, expected {expected_output}",
    )
0 commit comments