Skip to content

Commit 2ce88c9

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.) This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release). Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released. PiperOrigin-RevId: 730928808
1 parent 05614ed commit 2ce88c9

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

jax/_src/lax/linalg.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from jax._src.lax import svd as lax_svd
4242
from jax._src.lax import utils as lax_utils
4343
from jax._src.lax.lax import _float, _complex, _int
44+
from jax._src.lib import version as jaxlib_version
4445
from jax._src.lib.mlir import ir
4546
from jax._src.lib.mlir.dialects import chlo
4647
from jax._src.lib.mlir.dialects import hlo
@@ -2390,12 +2391,17 @@ def _triangular_solve_cpu_lower(
23902391
conjugate_a = False
23912392
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
23922393
target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype)
2393-
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
2394-
alpha_aval = ShapedArray((), a_aval.dtype)
2394+
# TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18.
2395+
if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1):
2396+
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)),
2397+
alpha_aval = ShapedArray((), a_aval.dtype),
2398+
else:
2399+
alpha = ()
2400+
alpha_aval = ()
23952401
rule = _linalg_ffi_lowering(target_name,
2396-
[a_aval, b_aval, alpha_aval],
2402+
[a_aval, b_aval, *alpha_aval],
23972403
operand_output_aliases={1: 0})
2398-
return rule(ctx, a, b, alpha,
2404+
return rule(ctx, a, b, *alpha,
23992405
side=_matrix_side_attr(left_side),
24002406
uplo=_matrix_uplo_attr(lower),
24012407
trans_x=_matrix_transpose_attr(transpose_a, conjugate_a),

jaxlib/cpu/lapack_kernels.cc

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ template struct Trsm<std::complex<double>>;
146146

147147
template <ffi::DataType dtype>
148148
ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
149-
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y, ffi::BufferR0<dtype> alpha,
149+
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y,
150+
// TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after
151+
// the release of JAX 0.5.2.
152+
ffi::RemainingArgs,
150153
ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
151154
MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x,
152155
MatrixParams::Diag diag) {
@@ -168,10 +171,10 @@ ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
168171
auto* x_data = x.typed_data();
169172
const int64_t y_out_step{y_rows * y_cols};
170173
const int64_t x_step{x_leading_dim_v * x_leading_dim_v};
174+
ffi::NativeType<dtype> alpha = static_cast<ffi::NativeType<dtype>>(1);
171175
for (int64_t i = 0; i < batch_count; ++i) {
172-
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v,
173-
alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data,
174-
&y_leading_dim_v);
176+
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, &alpha,
177+
x_data, &x_leading_dim_v, y_out_data, &y_leading_dim_v);
175178

176179
y_out_data += y_out_step;
177180
x_data += x_step;
@@ -2241,17 +2244,17 @@ template struct TridiagonalSolver<ffi::DataType::C128>;
22412244

22422245
// FFI Definition Macros (by DataType)
22432246

2244-
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
2245-
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
2246-
name, TriMatrixEquationSolver<data_type>::Kernel, \
2247-
::xla::ffi::Ffi::Bind() \
2248-
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
2249-
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
2250-
.Arg<::xla::ffi::BufferR0<data_type>>(/*alpha*/) \
2251-
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
2252-
.Attr<MatrixParams::Side>("side") \
2253-
.Attr<MatrixParams::UpLo>("uplo") \
2254-
.Attr<MatrixParams::Transpose>("trans_x") \
2247+
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
2248+
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
2249+
name, TriMatrixEquationSolver<data_type>::Kernel, \
2250+
::xla::ffi::Ffi::Bind() \
2251+
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
2252+
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
2253+
.RemainingArgs() \
2254+
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
2255+
.Attr<MatrixParams::Side>("side") \
2256+
.Attr<MatrixParams::UpLo>("uplo") \
2257+
.Attr<MatrixParams::Transpose>("trans_x") \
22552258
.Attr<MatrixParams::Diag>("diag"))
22562259

22572260
#define JAX_CPU_DEFINE_GETRF(name, data_type) \

jaxlib/cpu/lapack_kernels.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ struct TriMatrixEquationSolver {
147147
inline static FnType* fn = nullptr;
148148
static ::xla::ffi::Error Kernel(
149149
::xla::ffi::Buffer<dtype> x, ::xla::ffi::Buffer<dtype> y,
150-
::xla::ffi::BufferR0<dtype> alpha, ::xla::ffi::ResultBuffer<dtype> y_out,
150+
::xla::ffi::RemainingArgs, ::xla::ffi::ResultBuffer<dtype> y_out,
151151
MatrixParams::Side side, MatrixParams::UpLo uplo,
152152
MatrixParams::Transpose trans_x, MatrixParams::Diag diag);
153153
};

0 commit comments

Comments
 (0)