@@ -277,3 +277,80 @@ def stack(*items):
277277 return keras .ops .stack (items , axis = axis )
278278
279279 return keras .tree .map_structure (stack , * structures )
280+
281+
282+ def fill_triangular_matrix (x : Tensor , upper : bool = False , positive_diag : bool = False ):
283+ """
284+ Reshapes a batch of matrix entries into a triangular matrix (either upper or lower).
285+
286+ Note: If final axis has length 1, this simply reshapes to (batch_size, 1, 1) and optionally applies softplus.
287+
288+ Parameters
289+ ----------
290+ x : Tensor of shape (batch_size, m)
291+ Batch of flattened nonzero matrix elements for triangular matrix.
292+ upper : bool
293+ Return upper triangular matrix if True, else lower triangular matrix. Default is False.
294+ positive_diag : bool
295+ Whether to apply a softplus operation to diagonal elements. Default is False.
296+
297+ Returns
298+ -------
299+ Tensor of shape (batch_size, n, n)
300+ Batch of triangular matrices with m = n * (n + 1) / 2 unique nonzero elements.
301+
302+ Raises
303+ ------
304+ ValueError
305+ If provided nonzero elements do not correspond to possible triangular matrix shape
306+ (n,n) with n = sqrt( 1/4 + 2 * m) - 1/2 due to m = n * (n + 1) / 2.
307+ """
308+ batch_shape = x .shape [:- 1 ]
309+ m = x .shape [- 1 ]
310+
311+ if m == 1 :
312+ y = keras .ops .reshape (x , (- 1 , 1 , 1 ))
313+ if positive_diag :
314+ y = keras .activations .softplus (y )
315+ return y
316+
317+ # Calculate matrix shape
318+ n = (0.25 + 2 * m ) ** 0.5 - 0.5
319+ if not np .isclose (np .floor (n ), n ):
320+ raise ValueError (f"Input right-most shape ({ m } ) does not correspond to a triangular matrix." )
321+ else :
322+ n = int (n )
323+
324+ # Trick: Create triangular matrix by concatenating with a flipped version of its tail, then reshape.
325+ x_tail = keras .ops .take (x , indices = list (range ((m - (n ** 2 - m )), x .shape [- 1 ])), axis = - 1 )
326+ if not upper :
327+ y = keras .ops .concatenate ([x_tail , keras .ops .flip (x , axis = - 1 )], axis = len (batch_shape ))
328+ y = keras .ops .reshape (y , (- 1 , n , n ))
329+ y = keras .ops .tril (y ) # TODO: fails with tensorflow
330+
331+ if positive_diag :
332+ y_offdiag = keras .ops .tril (y , k = - 1 )
333+ y_diag = keras .ops .tril (
334+ keras .ops .triu ( # carve out diagonal, by setting upper and lower offdiagonals to zero
335+ keras .activations .softplus (y )
336+ ), # apply softplus to enforce positivity
337+ )
338+ y = y_diag + y_offdiag
339+
340+ else :
341+ y = keras .ops .concatenate ([x , keras .ops .flip (x_tail , axis = - 1 )], axis = len (batch_shape ))
342+ y = keras .ops .reshape (y , (- 1 , n , n ))
343+ y = keras .ops .triu (
344+ y ,
345+ )
346+
347+ if positive_diag :
348+ y_offdiag = keras .ops .triu (y , k = 1 )
349+ y_diag = keras .ops .tril (
350+ keras .ops .triu ( # carve out diagonal, by setting upper and lower offdiagonals to zero
351+ keras .activations .softplus (y )
352+ ), # apply softplus to enforce positivity
353+ )
354+ y = y_diag + y_offdiag
355+
356+ return y
0 commit comments