Skip to content

Commit 48cdc79

Browse files
authored
Add generic operator implementations.
Differential Revision: D87880917 Pull Request resolved: pytorch#15983
1 parent 09edea4 commit 48cdc79

File tree

9 files changed

+827
-0
lines changed

9 files changed

+827
-0
lines changed
Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cadence/generic/operators/op_linalg_svd.h>
10+
11+
#include <algorithm>
12+
#include <cmath>
13+
#include <tuple>
14+
15+
#include <executorch/runtime/core/error.h>
16+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
17+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
18+
19+
// Absolute tolerance below which the helpers in this file treat a magnitude
// (off-diagonal entry, vector norm, singular value) as zero.
constexpr float EPSILON = 1e-10f;

// M_PI is not part of standard C++; define it when <cmath> did not.
#if !defined(M_PI)
#define M_PI 3.14159265358979323846
#endif
23+
24+
namespace impl {
25+
namespace generic {
26+
namespace native {
27+
namespace {
28+
29+
using ::executorch::aten::ScalarType;
30+
using ::executorch::aten::Tensor;
31+
using ::executorch::runtime::Error;
32+
using ::executorch::runtime::KernelRuntimeContext;
33+
34+
// Plain aggregate holding a 3x3 single-precision matrix; the helpers in this
// file index it as m[row][column].
struct Matrix3x3 {
  float m[3][3];
};
38+
39+
// Returns the 3x3 identity matrix.
40+
Matrix3x3 identityMatrix() {
41+
Matrix3x3 I{};
42+
for (int i = 0; i < 3; i++) {
43+
for (int j = 0; j < 3; j++) {
44+
I.m[i][j] = (i == j) ? 1.0 : 0.0;
45+
}
46+
}
47+
return I;
48+
}
49+
50+
// Transposes matrix A.
51+
Matrix3x3 transpose(const Matrix3x3& A) {
52+
Matrix3x3 At{};
53+
for (int i = 0; i < 3; i++) {
54+
for (int j = 0; j < 3; j++) {
55+
At.m[i][j] = A.m[j][i];
56+
}
57+
}
58+
return At;
59+
}
60+
61+
// Multiplies matrices A and B.
62+
Matrix3x3 multiply(const Matrix3x3& A, const Matrix3x3& B) {
63+
Matrix3x3 C{};
64+
for (int i = 0; i < 3; i++) {
65+
for (int j = 0; j < 3; j++) {
66+
C.m[i][j] = 0.0;
67+
for (int k = 0; k < 3; k++) {
68+
C.m[i][j] += A.m[i][k] * B.m[k][j];
69+
}
70+
}
71+
}
72+
return C;
73+
}
74+
75+
// Jacobi method to compute the eigen-decomposition of a symmetric 3x3 matrix A.
76+
// It outputs the eigenvalues (in 'diag') and the eigenvectors as columns in V.
77+
void jacobiEigenDecomposition(const Matrix3x3& A, float diag[3], Matrix3x3& V) {
78+
Matrix3x3 D = A; // Make a copy; D will be transformed into a diagonal matrix.
79+
V = identityMatrix();
80+
81+
// Iterate until convergence (or max iterations)
82+
for (int iter = 0; iter < 100; iter++) {
83+
// Find the largest off-diagonal element in D.
84+
int p = 0, q = 1;
85+
float maxOff = std::fabs(D.m[0][1]);
86+
if (std::fabs(D.m[0][2]) > maxOff) {
87+
maxOff = std::fabs(D.m[0][2]);
88+
p = 0;
89+
q = 2;
90+
}
91+
if (std::fabs(D.m[1][2]) > maxOff) {
92+
maxOff = std::fabs(D.m[1][2]);
93+
p = 1;
94+
q = 2;
95+
}
96+
97+
if (maxOff < EPSILON) {
98+
break;
99+
}
100+
101+
// Compute the Jacobi rotation angle.
102+
float theta = 0.0;
103+
if (std::fabs(D.m[p][p] - D.m[q][q]) < EPSILON) {
104+
theta = M_PI / 4.0;
105+
} else {
106+
theta = 0.5 * std::atan2(2 * D.m[p][q], D.m[q][q] - D.m[p][p]);
107+
}
108+
109+
float c = std::cos(theta);
110+
float s = std::sin(theta);
111+
112+
// Update the diagonal elements.
113+
float D_pp = c * c * D.m[p][p] - 2 * s * c * D.m[p][q] + s * s * D.m[q][q];
114+
float D_qq = s * s * D.m[p][p] + 2 * s * c * D.m[p][q] + c * c * D.m[q][q];
115+
D.m[p][p] = D_pp;
116+
D.m[q][q] = D_qq;
117+
D.m[p][q] = D.m[q][p] = 0.0;
118+
119+
// Update the remaining elements.
120+
for (int j = 0; j < 3; j++) {
121+
if (j != p && j != q) {
122+
float D_pj = c * D.m[p][j] - s * D.m[q][j];
123+
float D_qj = s * D.m[p][j] + c * D.m[q][j];
124+
D.m[p][j] = D.m[j][p] = D_pj;
125+
D.m[q][j] = D.m[j][q] = D_qj;
126+
}
127+
}
128+
129+
// Update the eigenvector matrix V.
130+
for (int i = 0; i < 3; i++) {
131+
float V_ip = c * V.m[i][p] - s * V.m[i][q];
132+
float V_iq = s * V.m[i][p] + c * V.m[i][q];
133+
V.m[i][p] = V_ip;
134+
V.m[i][q] = V_iq;
135+
}
136+
}
137+
138+
diag[0] = D.m[0][0];
139+
diag[1] = D.m[1][1];
140+
diag[2] = D.m[2][2];
141+
}
142+
143+
// Sorts the eigenvalues (and the corresponding eigenvectors in V) in descending
144+
// order.
145+
void sortEigenDecomposition(float eigenvalues[3], Matrix3x3& V) {
146+
int indices[3] = {0, 1, 2};
147+
std::sort(indices, indices + 3, [&](int a, int b) {
148+
return eigenvalues[a] > eigenvalues[b];
149+
});
150+
151+
float sortedEigenvalues[3];
152+
Matrix3x3 sortedV{};
153+
for (int i = 0; i < 3; i++) {
154+
sortedEigenvalues[i] = eigenvalues[indices[i]];
155+
for (int j = 0; j < 3; j++) {
156+
sortedV.m[j][i] = V.m[j][indices[i]];
157+
}
158+
}
159+
for (int i = 0; i < 3; i++) {
160+
eigenvalues[i] = sortedEigenvalues[i];
161+
for (int j = 0; j < 3; j++) {
162+
V.m[j][i] = sortedV.m[j][i];
163+
}
164+
}
165+
}
166+
167+
// Writes the cross product a x b into result.
//
// All three components are computed before any is stored, so result may
// point to the same array as a or b.
void crossProduct(const float a[3], const float b[3], float result[3]) {
  const float x = a[1] * b[2] - a[2] * b[1];
  const float y = a[2] * b[0] - a[0] * b[2];
  const float z = a[0] * b[1] - a[1] * b[0];
  result[0] = x;
  result[1] = y;
  result[2] = z;
}
173+
174+
// Normalizes a 3D vector.
175+
void normalize(float v[3]) {
176+
float norm = std::sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]);
177+
if (norm > EPSILON) {
178+
v[0] /= norm;
179+
v[1] /= norm;
180+
v[2] /= norm;
181+
}
182+
}
183+
184+
// Computes the singular value decomposition of A such that A = U * S * Vt.
185+
// U and Vt are orthogonal matrices and S is a diagonal matrix with singular
186+
// values.
187+
std::tuple<Matrix3x3, Matrix3x3, Matrix3x3> svd(const Matrix3x3& A) {
188+
// Compute A^T * A (which is symmetric).
189+
Matrix3x3 At = transpose(A);
190+
Matrix3x3 ATA = multiply(At, A);
191+
192+
// Compute the eigen-decomposition of ATA.
193+
float eigenvalues[3];
194+
Matrix3x3 V{};
195+
jacobiEigenDecomposition(ATA, eigenvalues, V);
196+
sortEigenDecomposition(eigenvalues, V);
197+
198+
// The singular values are the square roots of the eigenvalues.
199+
float sigma[3];
200+
for (int i = 0; i < 3; i++) {
201+
sigma[i] = (eigenvalues[i] > 0.0) ? std::sqrt(eigenvalues[i]) : 0.0;
202+
}
203+
204+
// Compute U = A * V * S^{-1}.
205+
Matrix3x3 U{};
206+
for (int i = 0; i < 3; i++) {
207+
float av[3] = {0, 0, 0};
208+
// Multiply A by the i-th eigenvector (column of V).
209+
for (int r = 0; r < 3; r++) {
210+
for (int c = 0; c < 3; c++) {
211+
av[r] += A.m[r][c] * V.m[c][i];
212+
}
213+
}
214+
if (sigma[i] > EPSILON) {
215+
for (int r = 0; r < 3; r++) {
216+
U.m[r][i] = av[r] / sigma[i];
217+
}
218+
} else {
219+
// If sigma[i] is nearly zero, choose a vector orthogonal to the previous
220+
// ones.
221+
float vec[3] = {0, 0, 0};
222+
if (i == 1) {
223+
float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]};
224+
float tmp[3] = {1, 0, 0};
225+
float dot = u0[0] * tmp[0] + u0[1] * tmp[1] + u0[2] * tmp[2];
226+
if (std::fabs(dot) > 0.9) {
227+
tmp[0] = 0;
228+
tmp[1] = 1;
229+
tmp[2] = 0;
230+
}
231+
crossProduct(u0, tmp, vec);
232+
} else if (i == 2) {
233+
float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]};
234+
float u1[3] = {U.m[0][1], U.m[1][1], U.m[2][1]};
235+
crossProduct(u0, u1, vec);
236+
}
237+
normalize(vec);
238+
for (int r = 0; r < 3; r++) {
239+
U.m[r][i] = vec[r];
240+
}
241+
}
242+
}
243+
244+
// Construct the diagonal S matrix.
245+
Matrix3x3 S{};
246+
for (int i = 0; i < 3; i++) {
247+
for (int j = 0; j < 3; j++) {
248+
S.m[i][j] = (i == j) ? sigma[i] : 0.0;
249+
}
250+
}
251+
252+
// Vt is the transpose of V.
253+
Matrix3x3 Vt = transpose(V);
254+
255+
return std::make_tuple(U, S, Vt);
256+
}
257+
} // namespace
258+
259+
// Out-variant SVD kernel for batches of 3x3 float matrices.
//
// A's data is read as batch_size = A.numel() / 9 consecutive 3x3 matrices,
// each stored row by row. For batch entry i:
//   - U and Vh receive 9 floats each at element offset i * 9 (row by row);
//   - S receives the 3 singular values at element offset i * 3.
//
// full_matrices, compute_uv and driver are accepted but not consulted; U and
// Vh are always written.
//
// If A is not Float, is empty, does not hold a multiple of 9 elements, or an
// output's batch count (numel / 9 for U and Vh, numel / 3 for S) differs from
// A's, the failure is reported through ET_KERNEL_CHECK_MSG with
// InvalidArgument and the outputs are returned without being written.
//
// Returns references to (U, S, Vh).
std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& A,
    bool full_matrices,
    bool compute_uv,
    ::executorch::aten::optional<::executorch::aten::string_view> driver,
    Tensor& U,
    Tensor& S,
    Tensor& Vh) {
  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(U, S, Vh);

  ET_KERNEL_CHECK_MSG(
      ctx,
      A.scalar_type() == ScalarType::Float,
      InvalidArgument,
      ret_val,
      "input.dtype(): %s must be %s",
      ::torch::executor::toString(A.scalar_type()),
      ::torch::executor::toString(ScalarType::Float));

  ET_KERNEL_CHECK_MSG(
      ctx, A.numel() > 0, InvalidArgument, ret_val, "input.size() must be > 0");

  ET_KERNEL_CHECK_MSG(
      ctx,
      A.numel() % 9 == 0,
      InvalidArgument,
      ret_val,
      "SVD of only 3x3 matrix is supported! Expected the input to have (batch_size x 9) number of elements, but received %d elements instead",
      int(A.numel()));

  int batch_size = A.numel() / 9;

  ET_KERNEL_CHECK_MSG(
      ctx,
      U.numel() / 9 == batch_size,
      InvalidArgument,
      ret_val,
      "Output tensor U must have the same batch_size as input: %d, but got: %d instead",
      batch_size,
      int(U.numel() / 9));

  ET_KERNEL_CHECK_MSG(
      ctx,
      S.numel() / 3 == batch_size,
      InvalidArgument,
      ret_val,
      "Output tensor S must have the same batch_size as input: %d, but got: %d instead",
      batch_size,
      int(S.numel() / 3));

  ET_KERNEL_CHECK_MSG(
      ctx,
      Vh.numel() / 9 == batch_size,
      InvalidArgument,
      ret_val,
      "Output tensor Vh must have the same batch_size as input: %d, but got: %d instead",
      batch_size,
      int(Vh.numel() / 9));

  const float* A_data = A.const_data_ptr<float>();
  float* U_data = U.mutable_data_ptr<float>();
  float* S_data = S.mutable_data_ptr<float>();
  float* Vh_data = Vh.mutable_data_ptr<float>();

  for (int i = 0; i < batch_size; i++) {
    // A, U and Vh hold 9 floats per batch entry, S holds only 3. Using the
    // 9-element stride for S would scatter the singular values and write past
    // the end of S for batch_size > 1.
    const int mat_offset = i * 9;
    const int s_offset = i * 3;

    Matrix3x3 A_mat = {{
        {A_data[mat_offset + 0], A_data[mat_offset + 1], A_data[mat_offset + 2]},
        {A_data[mat_offset + 3], A_data[mat_offset + 4], A_data[mat_offset + 5]},
        {A_data[mat_offset + 6], A_data[mat_offset + 7], A_data[mat_offset + 8]},
    }};

    Matrix3x3 U_mat{}, S_mat{}, Vh_mat{};
    std::tie(U_mat, S_mat, Vh_mat) = svd(A_mat);

    // Flatten U and Vh row by row into their output buffers.
    for (int r = 0; r < 3; r++) {
      for (int c = 0; c < 3; c++) {
        U_data[mat_offset + r * 3 + c] = U_mat.m[r][c];
        Vh_data[mat_offset + r * 3 + c] = Vh_mat.m[r][c];
      }
    }

    // Only the diagonal of S_mat is stored.
    for (int k = 0; k < 3; k++) {
      S_data[s_offset + k] = S_mat.m[k][k];
    }
  }

  return ret_val;
}
362+
363+
} // namespace native
364+
} // namespace generic
365+
} // namespace impl

0 commit comments

Comments
 (0)