diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 551b23a90be..d7ec5bf05b3 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -60,7 +60,6 @@ def _validate_ref_impl_exists() -> None:
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
         "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
-        "cadence::linalg_svd",
         "cadence::quantized_conv2d_nhwc",  # We should only support per_tensor variant, should remove
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 53cb0845f42..f9f7249b249 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -21,11 +21,13 @@
 # Registry to track all ops with reference implementations
 _REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()
 
+_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+
 
 # Custom impl wrapper that tracks registrations
 def impl_tracked(
     lib: Library, op_name: str
-) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
     """Wrapper around impl that tracks registered ops."""
     _REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
     return impl(lib, op_name)
@@ -312,7 +314,7 @@ def quantized_add_per_tensor(
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
     # q_min/q_max are unused args
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         dequant_X + dequant_Y,
         out_scale,
         out_zero_point,
@@ -321,6 +323,9 @@ def quantized_add_per_tensor(
         dtype,
     )
 
+    assert isinstance(out, torch.Tensor)
+    return out
+
 
 @impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
 def quantized_add_asym8sxasym8s_asym8s_per_tensor(
@@ -338,9 +343,11 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
     if Y.dtype != torch.int8:
         raise ValueError("Y dtype must be torch.int8")
 
-    return quantized_add_per_tensor(
+    out = quantized_add_per_tensor(
         X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 @impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
@@ -359,9 +366,11 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
     if Y.dtype != torch.uint8:
         raise ValueError("Y dtype must be torch.int8")
 
-    return quantized_add_per_tensor(
+    out = quantized_add_per_tensor(
         X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 def quantized_linear_common(
@@ -407,14 +416,16 @@ def quantized_linear_common(
         (weight - weight_zero_point).float(),
         bias.float(),
     )
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         out,
         out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
         dtype,
-    ).reshape(*leading_dims, N)
+    )
+    assert isinstance(out, torch.Tensor)
+    return out.reshape(*leading_dims, N)
 
 
 def quantized_linear_variant(
@@ -576,7 +587,7 @@ def quantized_matmul(
         (X - X_zero_point).float(),
         (Y - Y_zero_point).float(),
     )
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         out,
         out_scale,
         out_zero_point,
@@ -584,6 +595,8 @@ def quantized_matmul(
         torch.iinfo(X.dtype).max,
         X.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 @impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s")
@@ -603,7 +616,7 @@ def quantized_matmul_asym8sxasym8s_asym8s(
     if Y.dtype != torch.int8:
         raise ValueError("Y dtype must be torch.int8")
 
-    return quantized_matmul(
+    out = quantized_matmul(
         X,
         X_zero_point,
         Y,
@@ -614,6 +627,8 @@ def quantized_matmul_asym8sxasym8s_asym8s(
         out_zero_point,
         transposed,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 @impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u")
@@ -633,7 +648,7 @@ def quantized_matmul_asym8uxasym8u_asym8u(
     if Y.dtype != torch.uint8:
         raise ValueError("Y dtype must be torch.uint8")
 
-    return quantized_matmul(
+    out = quantized_matmul(
         X,
         X_zero_point,
         Y,
@@ -644,6 +659,8 @@ def quantized_matmul_asym8uxasym8u_asym8u(
         out_zero_point,
         transposed,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 @impl_tracked(m, "quantized_layer_norm.per_tensor")
@@ -681,11 +698,12 @@ def quantized_layer_norm_per_tensor(
     float_input_tensor = dequantize_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
     )
+    assert isinstance(float_input_tensor, torch.Tensor)
     out = torch.nn.functional.layer_norm(
         float_input_tensor, normalized_shape, weight, bias, eps=eps
     )
 
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         out,
         output_scale,
         output_zero_point,
@@ -693,6 +711,8 @@ def quantized_layer_norm_per_tensor(
         torch.iinfo(input_tensor.dtype).max,
         input_tensor.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 def quantized_conv_per_tensor(
@@ -754,7 +774,7 @@ def quantized_conv_per_tensor(
     else:
         raise ValueError("Input tensor must be 3D or 4D")
 
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         float_out,
         output_scale,
         output_zero_point,
@@ -762,6 +782,8 @@ def quantized_conv_per_tensor(
         torch.iinfo(input_tensor.dtype).max,
         input_tensor.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 @impl_tracked(m, "quantized_conv2d_nchw.per_tensor")
@@ -983,7 +1005,7 @@ def variant(
             # Call the appropriate base function
             match layout:
                 case "nchw":
-                    return quantized_conv2d_nchw_per_tensor(
+                    out = quantized_conv2d_nchw_per_tensor(
                         input_tensor,
                         weight,
                         bias,
@@ -1000,7 +1022,7 @@ def variant(
                         out_shift,
                     )
                 case "nhwc":
-                    return quantized_conv2d_nhwc_per_tensor(
+                    out = quantized_conv2d_nhwc_per_tensor(
                         input_tensor,
                         weight,
                         bias,
@@ -1019,6 +1041,9 @@ def variant(
                 case _:
                     raise ValueError(f"Unknown layout {layout}")
 
+            assert isinstance(out, torch.Tensor)
+            return out
+
         return variant
 
     return decorator
@@ -1293,7 +1318,7 @@ def quantized_relu_common(
     dequantized_X = torch.where(
         X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
     ).to(torch.float32)
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         dequantized_X,
         out_scale,
         out_zero_point,
@@ -1301,6 +1326,8 @@ def quantized_relu_common(
         torch.iinfo(X.dtype).max,
         X.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 def quantized_relu_variant(
@@ -1557,7 +1584,7 @@ def im2row_per_tensor(
     in_zero_point: int,
     channel_last: bool = False,
 ) -> torch.Tensor:
-    return im2row(
+    out = im2row(
         input_tensor,
         kernel_size,
         dilation,
@@ -1566,6 +1593,8 @@ def im2row_per_tensor(
         torch.tensor(in_zero_point, dtype=torch.int32),
         channel_last,
     )
+    assert isinstance(out, torch.Tensor)
+    return out
 
 
 @impl_tracked(m, "transposed_im2row")
@@ -1773,3 +1802,15 @@ def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.T
 @impl_tracked(m, "idma_wait")
 def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
     return src.clone()
+
+
+@impl_tracked(m, "linalg_svd")
+def linalg_svd(
+    A: torch.Tensor,
+    full_matrices: bool = False,
+    compute_uv: bool = True,
+    driver: str | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    assert compute_uv
+    U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
+    return U.contiguous(), S.contiguous(), Vh.contiguous()
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 8d910d29e52..b3886f453f5 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2632,3 +2632,34 @@ def test_quantized_embedding_byte(
                 expected_out,
             )
         )
+
+    @expand(
+        [
+            *[
+                (
+                    dtype,
+                    (4, 4),
+                    full_matrices,
+                )
+                for dtype in [torch.float32, torch.float64]
+                for full_matrices in [True, False]
+            ]
+        ]
+    )
+    def test_linalg_svd_outputs_are_contiguous(
+        self,
+        dtype: torch.dtype,
+        shape: tuple[int, int],
+        full_matrices: bool,
+    ) -> None:
+        m, n = shape
+        a = torch.eye(m, n, dtype=dtype)
+
+        U, S, Vh = torch.ops.cadence.linalg_svd(a, full_matrices)
+
+        self.assertTrue(U.is_contiguous(), "U not contiguous")
+        self.assertTrue(S.is_contiguous(), "S not contiguous")
+        self.assertTrue(Vh.is_contiguous(), "Vh not contiguous")
+        self.assertTrue(U.dtype == dtype, "U dtype mismatch")
+        self.assertTrue(S.dtype == dtype, "S dtype mismatch")
+        self.assertTrue(Vh.dtype == dtype, "Vh dtype mismatch")
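
Usage sketch for the new cadence::linalg_svd reference implementation, mirroring the positional call pattern used in test_linalg_svd_outputs_are_contiguous above. The import paths are inferred from the file paths in this diff, and registration-on-import is an assumption, not verified behavior.

    import torch

    # Assumption: importing these modules defines and registers the cadence::
    # ops, including linalg_svd (module paths inferred from backends/cadence/aot/).
    import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
    import executorch.backends.cadence.aot.ref_implementations  # noqa: F401

    A = torch.randn(4, 4, dtype=torch.float32)

    # Same positional argument order as the added test: (input, full_matrices).
    U, S, Vh = torch.ops.cadence.linalg_svd(A, False)

    # The factors come back contiguous and reconstruct A up to float32 error.
    assert U.is_contiguous() and S.is_contiguous() and Vh.is_contiguous()
    assert torch.allclose(U @ torch.diag(S) @ Vh, A, atol=1e-4)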