|
5 | 5 |
|
6 | 6 | from pytensor.link.mlx.dispatch.basic import mlx_funcify
|
7 | 7 | from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
|
8 |
| -from pytensor.scalar import Softplus |
9 | 8 | from pytensor.scalar.basic import (
|
10 | 9 | AND,
|
| 10 | + EQ, |
| 11 | + GE, |
| 12 | + GT, |
| 13 | + LE, |
| 14 | + LT, |
| 15 | + NEQ, |
11 | 16 | OR,
|
| 17 | + Abs, |
12 | 18 | Add,
|
13 | 19 | Cast,
|
| 20 | + Cos, |
| 21 | + Exp, |
| 22 | + IntDiv, |
| 23 | + Invert, |
| 24 | + IsInf, |
| 25 | + IsNan, |
| 26 | + Log, |
| 27 | + Log1p, |
14 | 28 | Mul,
|
| 29 | + Neg, |
| 30 | + Pow, |
15 | 31 | ScalarMaximum,
|
16 | 32 | ScalarMinimum,
|
| 33 | + Sign, |
| 34 | + Sin, |
| 35 | + Sqr, |
| 36 | + Sqrt, |
| 37 | + Sub, |
| 38 | + Switch, |
| 39 | + TrueDiv, |
17 | 40 | )
|
18 |
| -from pytensor.tensor.elemwise import CAReduce, DimShuffle |
| 41 | +from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus |
| 42 | +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
19 | 43 | from pytensor.tensor.special import Softmax, SoftmaxGrad
|
20 | 44 |
|
21 | 45 |
|
@@ -50,23 +74,23 @@ def mlx_funcify_CAReduce(op, **kwargs):
|
50 | 74 |
|
51 | 75 |
|
52 | 76 | @mlx_funcify_CAReduce_scalar_op.register(Add)
|
53 |
| -def mlx_funcify_Elemwise_scalar_Add(scalar_op, axis): |
| 77 | +def mlx_funcify_CAReduce_scalar_Add(scalar_op, axis): |
54 | 78 | def sum_reduce(x):
|
55 | 79 | return mx.sum(x, axis=axis)
|
56 | 80 |
|
57 | 81 | return sum_reduce
|
58 | 82 |
|
59 | 83 |
|
60 | 84 | @mlx_funcify_CAReduce_scalar_op.register(Mul)
|
61 |
| -def mlx_funcify_Elemwise_scalar_Mul(scalar_op, axis): |
| 85 | +def mlx_funcify_CAReduce_scalar_Mul(scalar_op, axis): |
62 | 86 | def prod_reduce(x):
|
63 | 87 | return mx.prod(x, axis=axis)
|
64 | 88 |
|
65 | 89 | return prod_reduce
|
66 | 90 |
|
67 | 91 |
|
68 | 92 | @mlx_funcify_CAReduce_scalar_op.register(AND)
|
69 |
| -def mlx_funcify_Elemwise_scalar_AND(scalar_op, axis): |
| 93 | +def mlx_funcify_CAReduce_scalar_AND(scalar_op, axis): |
70 | 94 | def all_reduce(x):
|
71 | 95 | return x.all(axis=axis)
|
72 | 96 |
|
@@ -164,3 +188,259 @@ def cast(x):
|
164 | 188 | raise
|
165 | 189 |
|
166 | 190 | return cast
|
| 191 | + |
| 192 | + |
| 193 | +@singledispatch |
| 194 | +def mlx_funcify_Elemwise_scalar_op(scalar_op): |
| 195 | + """Simplified implementation for MLX scalar operations.""" |
| 196 | + |
| 197 | + # Try using the operation name directly (most common case) |
| 198 | + op_name = getattr(scalar_op, "name", None) |
| 199 | + if op_name is not None: |
| 200 | + try: |
| 201 | + mlx_func = getattr(mx, op_name) |
| 202 | + # Handle variadic functions like Add |
| 203 | + if hasattr(scalar_op, "inputs") and len(scalar_op.inputs) > 2: |
| 204 | + |
| 205 | + def variadic_func(*args): |
| 206 | + result = args[0] |
| 207 | + for arg in args[1:]: |
| 208 | + result = mlx_func(result, arg) |
| 209 | + return result |
| 210 | + |
| 211 | + return variadic_func |
| 212 | + else: |
| 213 | + return mlx_func |
| 214 | + except AttributeError: |
| 215 | + pass |
| 216 | + |
| 217 | + raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}") |
| 218 | + |
| 219 | + |
| 220 | +@mlx_funcify_Elemwise_scalar_op.register(Add) |
| 221 | +def mlx_funcify_Elemwise_scalar_Add(scalar_op): |
| 222 | + def add(*args): |
| 223 | + result = args[0] |
| 224 | + for arg in args[1:]: |
| 225 | + result = mx.add(result, arg) |
| 226 | + return result |
| 227 | + |
| 228 | + return add |
| 229 | + |
| 230 | + |
| 231 | +@mlx_funcify_Elemwise_scalar_op.register(Sub) |
| 232 | +def mlx_funcify_Elemwise_scalar_Sub(scalar_op): |
| 233 | + return mx.subtract |
| 234 | + |
| 235 | + |
| 236 | +@mlx_funcify_Elemwise_scalar_op.register(Mul) |
| 237 | +def mlx_funcify_Elemwise_scalar_Mul(scalar_op): |
| 238 | + def mul(*args): |
| 239 | + result = args[0] |
| 240 | + for arg in args[1:]: |
| 241 | + result = mx.multiply(result, arg) |
| 242 | + return result |
| 243 | + |
| 244 | + return mul |
| 245 | + |
| 246 | + |
| 247 | +@mlx_funcify_Elemwise_scalar_op.register(TrueDiv) |
| 248 | +def mlx_funcify_Elemwise_scalar_TrueDiv(scalar_op): |
| 249 | + return mx.divide |
| 250 | + |
| 251 | + |
| 252 | +@mlx_funcify_Elemwise_scalar_op.register(IntDiv) |
| 253 | +def mlx_funcify_Elemwise_scalar_IntDiv(scalar_op): |
| 254 | + return mx.floor_divide |
| 255 | + |
| 256 | + |
| 257 | +@mlx_funcify_Elemwise_scalar_op.register(Pow) |
| 258 | +def mlx_funcify_Elemwise_scalar_Pow(scalar_op): |
| 259 | + return mx.power |
| 260 | + |
| 261 | + |
| 262 | +@mlx_funcify_Elemwise_scalar_op.register(Exp) |
| 263 | +def mlx_funcify_Elemwise_scalar_Exp(scalar_op): |
| 264 | + return mx.exp |
| 265 | + |
| 266 | + |
| 267 | +@mlx_funcify_Elemwise_scalar_op.register(Log) |
| 268 | +def mlx_funcify_Elemwise_scalar_Log(scalar_op): |
| 269 | + return mx.log |
| 270 | + |
| 271 | + |
| 272 | +@mlx_funcify_Elemwise_scalar_op.register(Log1p) |
| 273 | +def mlx_funcify_Elemwise_scalar_Log1p(scalar_op): |
| 274 | + return mx.log1p |
| 275 | + |
| 276 | + |
| 277 | +@mlx_funcify_Elemwise_scalar_op.register(Sin) |
| 278 | +def mlx_funcify_Elemwise_scalar_Sin(scalar_op): |
| 279 | + return mx.sin |
| 280 | + |
| 281 | + |
| 282 | +@mlx_funcify_Elemwise_scalar_op.register(Cos) |
| 283 | +def mlx_funcify_Elemwise_scalar_Cos(scalar_op): |
| 284 | + return mx.cos |
| 285 | + |
| 286 | + |
| 287 | +@mlx_funcify_Elemwise_scalar_op.register(Sqrt) |
| 288 | +def mlx_funcify_Elemwise_scalar_Sqrt(scalar_op): |
| 289 | + return mx.sqrt |
| 290 | + |
| 291 | + |
| 292 | +@mlx_funcify_Elemwise_scalar_op.register(Sqr) |
| 293 | +def mlx_funcify_Elemwise_scalar_Sqr(scalar_op): |
| 294 | + return mx.square |
| 295 | + |
| 296 | + |
| 297 | +@mlx_funcify_Elemwise_scalar_op.register(Abs) |
| 298 | +def mlx_funcify_Elemwise_scalar_Abs(scalar_op): |
| 299 | + return mx.abs |
| 300 | + |
| 301 | + |
| 302 | +@mlx_funcify_Elemwise_scalar_op.register(Neg) |
| 303 | +def mlx_funcify_Elemwise_scalar_Neg(scalar_op): |
| 304 | + return mx.negative |
| 305 | + |
| 306 | + |
| 307 | +@mlx_funcify_Elemwise_scalar_op.register(Sign) |
| 308 | +def mlx_funcify_Elemwise_scalar_Sign(scalar_op): |
| 309 | + return mx.sign |
| 310 | + |
| 311 | + |
| 312 | +@mlx_funcify_Elemwise_scalar_op.register(LE) |
| 313 | +def mlx_funcify_Elemwise_scalar_LE(scalar_op): |
| 314 | + return mx.less_equal |
| 315 | + |
| 316 | + |
| 317 | +@mlx_funcify_Elemwise_scalar_op.register(LT) |
| 318 | +def mlx_funcify_Elemwise_scalar_LT(scalar_op): |
| 319 | + return mx.less |
| 320 | + |
| 321 | + |
| 322 | +@mlx_funcify_Elemwise_scalar_op.register(GE) |
| 323 | +def mlx_funcify_Elemwise_scalar_GE(scalar_op): |
| 324 | + return mx.greater_equal |
| 325 | + |
| 326 | + |
| 327 | +@mlx_funcify_Elemwise_scalar_op.register(GT) |
| 328 | +def mlx_funcify_Elemwise_scalar_GT(scalar_op): |
| 329 | + return mx.greater |
| 330 | + |
| 331 | + |
| 332 | +@mlx_funcify_Elemwise_scalar_op.register(EQ) |
| 333 | +def mlx_funcify_Elemwise_scalar_EQ(scalar_op): |
| 334 | + return mx.equal |
| 335 | + |
| 336 | + |
| 337 | +@mlx_funcify_Elemwise_scalar_op.register(NEQ) |
| 338 | +def mlx_funcify_Elemwise_scalar_NEQ(scalar_op): |
| 339 | + return mx.not_equal |
| 340 | + |
| 341 | + |
| 342 | +@mlx_funcify_Elemwise_scalar_op.register(Switch) |
| 343 | +def mlx_funcify_Elemwise_scalar_Switch(scalar_op): |
| 344 | + return mx.where |
| 345 | + |
| 346 | + |
| 347 | +@mlx_funcify_Elemwise_scalar_op.register(AND) |
| 348 | +def mlx_funcify_Elemwise_scalar_AND(scalar_op): |
| 349 | + return mx.bitwise_and |
| 350 | + |
| 351 | + |
| 352 | +@mlx_funcify_Elemwise_scalar_op.register(OR) |
| 353 | +def mlx_funcify_Elemwise_scalar_OR(scalar_op): |
| 354 | + return mx.bitwise_or |
| 355 | + |
| 356 | + |
| 357 | +@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum) |
| 358 | +def mlx_funcify_Elemwise_scalar_ScalarMaximum(scalar_op): |
| 359 | + return mx.maximum |
| 360 | + |
| 361 | + |
| 362 | +@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum) |
| 363 | +def mlx_funcify_Elemwise_scalar_ScalarMinimum(scalar_op): |
| 364 | + return mx.minimum |
| 365 | + |
| 366 | + |
| 367 | +@mlx_funcify_Elemwise_scalar_op.register(Cast) |
| 368 | +def mlx_funcify_Elemwise_scalar_Cast(scalar_op): |
| 369 | + def cast(x): |
| 370 | + dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype) |
| 371 | + try: |
| 372 | + return x.astype(dtype) |
| 373 | + except ValueError as e: |
| 374 | + if "is not supported on the GPU" in str(e): |
| 375 | + import warnings |
| 376 | + |
| 377 | + warnings.warn( |
| 378 | + f"MLX GPU limitation: {e}. Attempting automatic fallback casting.", |
| 379 | + UserWarning, |
| 380 | + stacklevel=2, |
| 381 | + ) |
| 382 | + fallback_dtype = convert_dtype_to_mlx( |
| 383 | + scalar_op.o_type.dtype, auto_cast_unsupported=True |
| 384 | + ) |
| 385 | + return x.astype(fallback_dtype) |
| 386 | + else: |
| 387 | + raise e |
| 388 | + |
| 389 | + return cast |
| 390 | + |
| 391 | + |
| 392 | +@mlx_funcify_Elemwise_scalar_op.register(Sigmoid) |
| 393 | +def mlx_funcify_Elemwise_scalar_Sigmoid(scalar_op): |
| 394 | + return mx.sigmoid |
| 395 | + |
| 396 | + |
| 397 | +@mlx_funcify_Elemwise_scalar_op.register(Invert) |
| 398 | +def mlx_funcify_Elemwise_scalar_Invert(scalar_op): |
| 399 | + return mx.bitwise_invert |
| 400 | + |
| 401 | + |
| 402 | +@mlx_funcify_Elemwise_scalar_op.register(IsNan) |
| 403 | +def mlx_funcify_Elemwise_scalar_IsNan(scalar_op): |
| 404 | + return mx.isnan |
| 405 | + |
| 406 | + |
| 407 | +@mlx_funcify_Elemwise_scalar_op.register(IsInf) |
| 408 | +def mlx_funcify_Elemwise_scalar_IsInf(scalar_op): |
| 409 | + return mx.isinf |
| 410 | + |
| 411 | + |
| 412 | +@mlx_funcify_Elemwise_scalar_op.register(Erfc) |
| 413 | +def mlx_funcify_Elemwise_scalar_Erfc(scalar_op): |
| 414 | + def erfc(x): |
| 415 | + return 1.0 - mx.erf(x) |
| 416 | + |
| 417 | + return erfc |
| 418 | + |
| 419 | + |
| 420 | +@mlx_funcify_Elemwise_scalar_op.register(Erfcx) |
| 421 | +def mlx_funcify_Elemwise_scalar_Erfcx(scalar_op): |
| 422 | + def erfcx(x): |
| 423 | + return mx.exp(x * x) * (1.0 - mx.erf(x)) |
| 424 | + |
| 425 | + return erfcx |
| 426 | + |
| 427 | + |
| 428 | +@mlx_funcify_Elemwise_scalar_op.register(Softplus) |
| 429 | +def mlx_funcify_Elemwise_scalar_softplus(scalar_op): |
| 430 | + def softplus(x): |
| 431 | + # Numerically stable implementation of log(1 + exp(x)) |
| 432 | + # Following the same logic as the original PyTensor implementation |
| 433 | + return mx.where( |
| 434 | + x < -37.0, |
| 435 | + mx.exp(x), |
| 436 | + mx.where( |
| 437 | + x < 18.0, mx.log1p(mx.exp(x)), mx.where(x < 33.3, x + mx.exp(-x), x) |
| 438 | + ), |
| 439 | + ) |
| 440 | + |
| 441 | + return softplus |
| 442 | + |
| 443 | + |
| 444 | +@mlx_funcify.register(Elemwise) |
| 445 | +def mlx_funcify_Elemwise(op, node, **kwargs): |
| 446 | + return mlx_funcify_Elemwise_scalar_op(op.scalar_op) |
0 commit comments