
Commit 1256153

Paweł Paruzel authored and Google-ML-Automation committed
Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
1 parent 3d9c720 commit 1256153

9 files changed (+352, -47 lines)

jax/_src/export/_export.py

Lines changed: 1 addition & 0 deletions
@@ -1017,6 +1017,7 @@ def _check_lowering(lowering) -> None:
     "lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi",
     "lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi",
     "lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi",
+    "lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
 ]
 # These are the JAX custom call target names that are guaranteed to be stable.
 # Their backwards compatibility is tested by back_compat_test.py.
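This allowlist is what jax.export consults when validating a lowering: custom call targets that are not listed are rejected as unstable. A minimal sketch of exercising the newly listed targets (assuming a recent JAX/jaxlib whose CPU lowering emits the FFI TRSM kernel):

import jax
import jax.numpy as jnp
from jax import export

def solve(a, b):
  # For 2D operands with LAPACK-supported dtypes, the CPU lowering
  # dispatches to the ?trsm kernel.
  return jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=True)

a = jnp.tril(jnp.ones((3, 3))) + 2.0 * jnp.eye(3)  # nonsingular lower-triangular
b = jnp.ones((3, 2))
exported = export.export(jax.jit(solve))(a, b)  # passes _check_lowering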

jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py

Lines changed: 278 additions & 0 deletions
Large diffs are not rendered by default.

jax/_src/lax/linalg.py

Lines changed: 7 additions & 4 deletions
@@ -1329,7 +1329,6 @@ def _triangular_solve_lowering(
       ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
       hlo.TransposeAttr.get(transpose))]

-mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)

 def _triangular_solve_cpu_lower(
     ctx, a, b, *, left_side, lower, transpose_a,
@@ -1342,10 +1341,12 @@ def _triangular_solve_cpu_lower(
   if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
     alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
     b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
+    # TODO(b/344892332): Remove the conditional after the compatibility period.
+    ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
     return lapack.trsm_hlo(
-        a_aval.dtype, alpha,
-        a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
-        b_shape_vals=b_shape_vals)
+        *ctx_args, a_aval.dtype, alpha,
+        a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
+        b_shape_vals=b_shape_vals)
   else:
     # Fall back to the HLO implementation for unsupported types or batching.
     # TODO: Consider swapping XLA for LAPACK in batched case
@@ -1358,6 +1359,8 @@
         ir.BoolAttr.get(unit_diagonal),
         hlo.TransposeAttr.get(transpose))]

+
+mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
 mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
                        platform='cpu')

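The ctx_args shim above is the usual pattern for a jaxlib compatibility window: the lowering context is passed positionally only when the installed jaxlib is new enough for trsm_hlo to accept it. A self-contained sketch of the same idiom, with hypothetical stand-ins for the kernel and the version tuple:

JAXLIB_VERSION = (0, 4, 36)  # hypothetical installed jaxlib version

def trsm_hlo_stub(*args):
  # Stand-in for lapack.trsm_hlo; just records how it was called.
  return args

def lower_triangular_solve(ctx, dtype, alpha, a, b):
  # Pass `ctx` positionally only when jaxlib >= 0.4.37 expects it.
  ctx_args = (ctx,) if JAXLIB_VERSION >= (0, 4, 37) else ()
  return trsm_hlo_stub(*ctx_args, dtype, alpha, a, b)

print(lower_triangular_solve("ctx", "f32", 1.0, "a", "b"))  # old jaxlib: no ctx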

jaxlib/cpu/cpu_kernels.cc

Lines changed: 4 additions & 4 deletions
@@ -117,10 +117,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(

 // FFI Kernels

-JAX_CPU_REGISTER_HANDLER(blas_strsm_ffi);
-JAX_CPU_REGISTER_HANDLER(blas_dtrsm_ffi);
-JAX_CPU_REGISTER_HANDLER(blas_ctrsm_ffi);
-JAX_CPU_REGISTER_HANDLER(blas_ztrsm_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_strsm_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_dtrsm_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_ctrsm_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_ztrsm_ffi);
 JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
 JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
 JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);

jaxlib/cpu/lapack.cc

Lines changed: 4 additions & 4 deletions
@@ -234,10 +234,10 @@ nb::dict Registrations() {
   dict["lapack_zhetrd"] =
       EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);

-  dict["blas_strsm_ffi"] = EncapsulateFunction(blas_strsm_ffi);
-  dict["blas_dtrsm_ffi"] = EncapsulateFunction(blas_dtrsm_ffi);
-  dict["blas_ctrsm_ffi"] = EncapsulateFunction(blas_ctrsm_ffi);
-  dict["blas_ztrsm_ffi"] = EncapsulateFunction(blas_ztrsm_ffi);
+  dict["lapack_strsm_ffi"] = EncapsulateFunction(lapack_strsm_ffi);
+  dict["lapack_dtrsm_ffi"] = EncapsulateFunction(lapack_dtrsm_ffi);
+  dict["lapack_ctrsm_ffi"] = EncapsulateFunction(lapack_ctrsm_ffi);
+  dict["lapack_ztrsm_ffi"] = EncapsulateFunction(lapack_ztrsm_ffi);
   dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi);
   dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
   dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);
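These dict entries matter because the Python side iterates Registrations() and hands each capsule to XLA, so every key must match the handler name that the lowering emits; stale blas_* keys would leave lapack_*trsm_ffi custom calls unresolved. A hedged sketch of that consumption path (the module path and the api_version split are assumptions about jaxlib internals):

from jaxlib import xla_client
from jaxlib.cpu import _lapack  # assumed module path

for name, fn in _lapack.registrations().items():
  # Typed-FFI handlers (the *_ffi entries) register under the FFI API;
  # the remaining entries are legacy custom calls.
  api_version = 1 if name.endswith("_ffi") else 0
  xla_client.register_custom_call_target(
      name, fn, platform="cpu", api_version=api_version)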

jaxlib/cpu/lapack_kernels.cc

Lines changed: 4 additions & 4 deletions
@@ -2128,10 +2128,10 @@ template struct TridiagonalReduction<ffi::DataType::C128>;

 // FFI Handlers

-JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
-JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64);
-JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64);
-JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128);
+JAX_CPU_DEFINE_TRSM(lapack_strsm_ffi, ::xla::ffi::DataType::F32);
+JAX_CPU_DEFINE_TRSM(lapack_dtrsm_ffi, ::xla::ffi::DataType::F64);
+JAX_CPU_DEFINE_TRSM(lapack_ctrsm_ffi, ::xla::ffi::DataType::C64);
+JAX_CPU_DEFINE_TRSM(lapack_ztrsm_ffi, ::xla::ffi::DataType::C128);

 JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
 JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);

jaxlib/cpu/lapack_kernels.h

Lines changed: 4 additions & 4 deletions
@@ -741,10 +741,10 @@ struct TridiagonalReduction {
 };

 // Declare all the handler symbols
-XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi);
-XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi);
-XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ctrsm_ffi);
-XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ztrsm_ffi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_strsm_ffi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dtrsm_ffi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ctrsm_ffi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ztrsm_ffi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgetrf_ffi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgetrf_ffi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgetrf_ffi);

jaxlib/lapack.py

Lines changed: 41 additions & 27 deletions
@@ -118,52 +118,66 @@ def build_lapack_fn_target(fn_base: str, dtype) -> str:

 # ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
 # triangular solve
-def trsm_hlo(dtype, alpha, a, b,
+def trsm_hlo(ctx, dtype, alpha, a, b,
              left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False, *,
              b_shape_vals: tuple[DimensionSize, ...]):
-  _lapack.initialize()
+  if conj_a and not trans_a:
+    raise NotImplementedError("Conjugation without transposition not supported")
+  fn_base = prepare_lapack_call(fn_base="trsm", dtype=dtype)
   b_type = ir.RankedTensorType(b.type)

-  m, n = b_shape_vals[-2:]
   batch_dims_vals = b_shape_vals[:-2]
   num_bd = len(batch_dims_vals)
-  batch_size_val = hlo_s32(1)
-  for b_v in batch_dims_vals:
-    batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
-
-  if dtype == np.float32:
-    fn = "blas_strsm"
-  elif dtype == np.float64:
-    fn = "blas_dtrsm"
-  elif dtype == np.complex64:
-    fn = "blas_ctrsm"
-  elif dtype == np.complex128:
-    fn = "blas_ztrsm"
-  else:
-    raise NotImplementedError(f"Unsupported dtype {dtype}")
-
-  if conj_a and not trans_a:
-    raise NotImplementedError("Conjugation without transposition not supported")
   scalar_layout = []
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
   result_types, result_shapes = mk_result_types_and_shapes(
       [(b_shape_vals, b_type.element_type)])
+
+  if ctx.is_forward_compat():
+    # The old TRSM kernel name is prefixed with "blas"
+    fn = fn_base.replace("lapack", "blas", 1)
+    m, n = b_shape_vals[-2:]
+    batch_size_val = hlo_s32(1)
+    for b_v in batch_dims_vals:
+      batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
+    result_types, result_shapes = mk_result_types_and_shapes(
+        [(b_shape_vals, b_type.element_type)]
+    )
+    return custom_call(
+        fn,
+        result_types=result_types,
+        operands=[hlo_s32(int(left_side)), hlo_s32(int(lower)),
+                  hlo_s32((2 if conj_a else 1) if trans_a else 0), hlo_s32(int(diag)),
+                  ensure_hlo_s32(m), ensure_hlo_s32(n), batch_size_val,
+                  alpha, a, b],
+        operand_layouts=[scalar_layout] * 8 + [layout] * 2,
+        result_layouts=[layout],
+        operand_output_aliases={9: 0},
+        result_shapes=result_shapes,
+    ).results
+
+  fn = fn_base + "_ffi"
   return custom_call(
       fn,
       result_types=result_types,
-      operands=[hlo_s32(int(left_side)), hlo_s32(int(lower)),
-                hlo_s32((2 if conj_a else 1) if trans_a else 0), hlo_s32(int(diag)),
-                ensure_hlo_s32(m), ensure_hlo_s32(n), batch_size_val,
-                alpha, a, b],
-      operand_layouts=[scalar_layout] * 8 + [layout] * 2,
+      operands=[a, b, alpha],
+      operand_layouts=[layout] * 2 + [scalar_layout],
       result_layouts=[layout],
-      operand_output_aliases={9: 0},
+      operand_output_aliases={1: 0},
       result_shapes=result_shapes,
+      backend_config={
+          "side": _matrix_side_attr(left_side=left_side),
+          "uplo": _matrix_uplo_attr(lower=lower),
+          "trans_x": _matrix_transpose_attr(
+              transpose=trans_a, conjugate=conj_a
+          ),
+          "diag": _matrix_diagonal_attr(unit_diag=diag),
+      },
+      api_version=4,
   ).results


-
 # ?potrf: Cholesky decomposition

 def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False,
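The FFI variant moves every mode flag out of the operand list and into backend_config attributes, so the call carries only the tensors a and b plus alpha, and the in-place alias shifts from {9: 0} to {1: 0} because b is now operand 1. One way to see the result is to inspect a lowered module; for plain jit lowering (which does not take the forward-compatibility path), a CPU build with jaxlib >= 0.4.37 should show a stablehlo.custom_call targeting lapack_strsm_ffi with the side/uplo/trans_x/diag attributes in its backend_config. A sketch, not a pinned output:

import jax
import jax.numpy as jnp

def solve(a, b):
  return jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=True)

a = jnp.tril(jnp.ones((3, 3))) + 2.0 * jnp.eye(3)
b = jnp.ones((3, 2))
hlo_text = jax.jit(solve).lower(a, b).as_text()
# Expect the custom call line to name lapack_strsm_ffi on new jaxlib.
print([line for line in hlo_text.splitlines() if "trsm" in line])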

tests/export_back_compat_test.py

Lines changed: 9 additions & 0 deletions
@@ -122,6 +122,7 @@ def test_custom_call_coverage(self):
         cpu_eigh_lapack_syev.data_2024_08_19,
         cpu_lu_lapack_getrf.data_2024_05_31,
         cpu_schur_lapack_gees.data_2024_11_29,
+        cpu_triangular_solve_blas_trsm.data_2024_12_02,
         cpu_svd_lapack_gesdd.data_2024_08_13,
         cpu_hessenberg_lapack_gehrd.data_2024_08_31,
         cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01,
@@ -741,6 +742,14 @@ def check_triangular_solve_results(res_run, res_expected, *, rtol, atol):

     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=check_triangular_solve_results)
+    # TODO(b/344892332): Remove the check after the compatibility period.
+    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
+    if has_xla_ffi_support:
+      with config.export_ignore_forward_compatibility(True):
+        # FFI Kernel test
+        data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
+        self.run_one_test(func, data, rtol=rtol, atol=atol,
+                          check_results=check_triangular_solve_results)

   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
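During the compatibility window, export lowers through the legacy blas_*trsm target so that serialized artifacts still run on older runtimes; the context manager in the added test opts out of that, giving the new FFI path coverage as well. A sketch of the same toggle outside the test harness (assuming config is jax._src.config, as in the test):

import jax
import jax.numpy as jnp
from jax import export
from jax._src import config  # assumption: the config module the test uses

def solve(a, b):
  return jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=True)

a = jnp.tril(jnp.ones((3, 3))) + 2.0 * jnp.eye(3)
b = jnp.ones((3, 2))
# With the flag set, the export should carry lapack_strsm_ffi rather than
# the forward-compatible blas_strsm target.
with config.export_ignore_forward_compatibility(True):
  exported = export.export(jax.jit(solve))(a, b)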
