Skip to content

Commit 27db4c3

Browse files
committed
Add sparse arg in one_hot/multi_hot
1 parent 7b58304 commit 27db4c3

File tree

1 file changed

+6
-2
lines changed
  • keras/src/backend/mlx

1 file changed

+6
-2
lines changed

keras/src/backend/mlx/nn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def conv_transpose(
167167
raise NotImplementedError("MLX backend doesn't support conv transpose yet")
168168

169169

170-
def one_hot(x, num_classes, axis=-1, dtype="float32"):
170+
def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
171+
if sparse:
172+
raise ValueError("Unsupported value `sparse=True` with mlx backend")
171173
x = convert_to_tensor(x, dtype=mx.int32)
172174
dtype = to_mlx_dtype(standardize_dtype(dtype))
173175

@@ -181,7 +183,9 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"):
181183
return output
182184

183185

184-
def multi_hot(x, num_classes, axis=-1, dtype="float32"):
186+
def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
187+
if sparse:
188+
raise ValueError("Unsupported value `sparse=True` with mlx backend")
185189
x = convert_to_tensor(x)
186190
reduction_axis = 1 if x.ndim > 1 else 0
187191
return one_hot(x, num_classes, axis=axis, dtype=dtype).max(

0 commit comments

Comments
 (0)