diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp index fa8951a1e..f02e025ae 100644 --- a/functorch/csrc/BatchRulesDecompositions.cpp +++ b/functorch/csrc/BatchRulesDecompositions.cpp @@ -164,6 +164,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { OP_DECOMPOSE2(less, Tensor ); OP_DECOMPOSE(linalg_cond); OP_DECOMPOSE(linalg_det); + OP_DECOMPOSE(linalg_lu_factor); OP_DECOMPOSE(linalg_matmul); OP_DECOMPOSE(linalg_svd); OP_DECOMPOSE(matmul); diff --git a/functorch/csrc/BatchRulesLinearAlgebra.cpp b/functorch/csrc/BatchRulesLinearAlgebra.cpp index 2f06af1d4..9b1d84169 100644 --- a/functorch/csrc/BatchRulesLinearAlgebra.cpp +++ b/functorch/csrc/BatchRulesLinearAlgebra.cpp @@ -185,6 +185,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VARIADIC_BDIMS_BOXED(_det_lu_based_helper); VARIADIC_BDIMS_BOXED(_lu_with_info); + VARIADIC_BDIMS_BOXED(linalg_lu_factor_ex); } }} diff --git a/test/functorch_lagging_op_db.py b/test/functorch_lagging_op_db.py index 7befe16d6..f810ae86d 100644 --- a/test/functorch_lagging_op_db.py +++ b/test/functorch_lagging_op_db.py @@ -242,6 +242,8 @@ ('linalg.inv_ex', ''), ('linalg.lstsq', ''), ('linalg.lstsq', 'grad_oriented'), + ('linalg.lu_factor', ''), + ('linalg.lu_factor_ex', ''), ('linalg.matrix_norm', ''), ('linalg.matrix_power', ''), ('linalg.matrix_rank', ''),