@@ -118,52 +118,75 @@ def build_lapack_fn_target(fn_base: str, dtype) -> str:

 # ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
 # triangular solve
-def trsm_hlo(dtype, alpha, a, b,
+def trsm_hlo(ctx, dtype, alpha, a, b,
              left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False, *,
              b_shape_vals: tuple[DimensionSize, ...]):
-  _lapack.initialize()
+  if conj_a and not trans_a:
+    raise NotImplementedError("Conjugation without transposition not supported")
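+  # prepare_lapack_call replaces the explicit _lapack.initialize() call and
+  # builds the per-dtype kernel name (e.g. "lapack_strsm" for float32).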
+  fn_base = prepare_lapack_call(fn_base="trsm", dtype=dtype)
   b_type = ir.RankedTensorType(b.type)

-  m, n = b_shape_vals[-2:]
   batch_dims_vals = b_shape_vals[:-2]
   num_bd = len(batch_dims_vals)
-  batch_size_val = hlo_s32(1)
-  for b_v in batch_dims_vals:
-    batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
-
-  if dtype == np.float32:
-    fn = "blas_strsm"
-  elif dtype == np.float64:
-    fn = "blas_dtrsm"
-  elif dtype == np.complex64:
-    fn = "blas_ctrsm"
-  elif dtype == np.complex128:
-    fn = "blas_ztrsm"
-  else:
-    raise NotImplementedError(f"Unsupported dtype {dtype}")
-
-  if conj_a and not trans_a:
-    raise NotImplementedError("Conjugation without transposition not supported")
   scalar_layout = []
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
   result_types, result_shapes = mk_result_types_and_shapes(
       [(b_shape_vals, b_type.element_type)])
+
+  if ctx.is_forward_compat():
+    # The old TRSM kernel name is prefixed with "blas"
+    fn = fn_base.replace("lapack", "blas", 1)
+    m, n = b_shape_vals[-2:]
+    batch_size_val = hlo_s32(1)
+    for b_v in batch_dims_vals:
+      batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
+    result_types, result_shapes = mk_result_types_and_shapes(
+        [(b_shape_vals, b_type.element_type)]
+    )
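+    # The legacy kernel takes the problem dimensions and batch size as
+    # explicit s32 scalar operands and writes the solution into b (operand 9).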
+    return custom_call(
+        fn,
+        result_types=result_types,
+        operands=[hlo_s32(int(left_side)), hlo_s32(int(lower)),
+                  hlo_s32((2 if conj_a else 1) if trans_a else 0), hlo_s32(int(diag)),
+                  ensure_hlo_s32(m), ensure_hlo_s32(n), batch_size_val,
+                  alpha, a, b],
+        operand_layouts=[scalar_layout] * 8 + [layout] * 2,
+        result_layouts=[layout],
+        operand_output_aliases={9: 0},
+        result_shapes=result_shapes,
+    ).results
+
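+  # The new FFI kernel infers the dimensions and batch size from the operand
+  # shapes, so only the matrices and the scalar alpha are passed.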
160+ fn = fn_base + "_ffi"
152161 return custom_call (
153162 fn ,
154163 result_types = result_types ,
155- operands = [hlo_s32 (int (left_side )), hlo_s32 (int (lower )),
156- hlo_s32 ((2 if conj_a else 1 ) if trans_a else 0 ), hlo_s32 (int (diag )),
157- ensure_hlo_s32 (m ), ensure_hlo_s32 (n ), batch_size_val ,
158- alpha , a , b ],
159- operand_layouts = [scalar_layout ] * 8 + [layout ] * 2 ,
164+ operands = [a , b , alpha ],
165+ operand_layouts = [layout ] * 2 + [scalar_layout ],
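+      # Alias b (operand 1) to the output so the solve is done in place.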
       result_layouts=[layout],
-      operand_output_aliases={9: 0},
+      operand_output_aliases={1: 0},
       result_shapes=result_shapes,
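+      # With api_version=4 (XLA's typed FFI convention), the backend_config
+      # entries are passed to the kernel as named attributes.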
+      backend_config={
+          "side": _matrix_side_attr(left_side=left_side),
+          "uplo": _matrix_uplo_attr(lower=lower),
+          "trans_x": _matrix_transpose_attr(
+              transpose=trans_a, conjugate=conj_a
+          ),
+          "diag": _matrix_diagonal_attr(unit_diag=diag),
+      },
+      api_version=4,
   ).results


-
 # ?potrf: Cholesky decomposition

 def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False,