|
25 | 25 | MLX_DTYPES = {
|
26 | 26 | "float16": mx.float16,
|
27 | 27 | "float32": mx.float32,
|
28 |
| - "float64": None, # mlx only supports float64 on cpu |
| 28 | + "float64": mx.float64, # for mlx float64 only supported on cpu |
29 | 29 | "uint8": mx.uint8,
|
30 | 30 | "uint16": mx.uint16,
|
31 | 31 | "uint32": mx.uint32,
|
@@ -104,19 +104,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
|
104 | 104 | return mx.array(x, dtype=mlx_dtype)
|
105 | 105 |
|
106 | 106 | if isinstance(x, list):
|
107 |
| - |
108 |
| - def to_scalar_list(x): |
109 |
| - if isinstance(x, list): |
110 |
| - return [to_scalar_list(xi) for xi in x] |
111 |
| - elif isinstance(x, mx.array): |
112 |
| - if x.ndim == 0: |
113 |
| - return x.item() |
114 |
| - else: |
115 |
| - return x.tolist() |
116 |
| - else: |
117 |
| - return x |
118 |
| - |
119 |
| - return mx.array(to_scalar_list(x), dtype=mlx_dtype) |
| 107 | + return mx.array(x, dtype=mlx_dtype) |
120 | 108 |
|
121 | 109 | if _is_h5py_dataset(x):
|
122 | 110 | if h5py is None:
|
@@ -592,3 +580,41 @@ def __init__(self, fun):
|
592 | 580 | def __call__(self, *args, **kwargs):
|
593 | 581 | outputs, _ = self.fun(*args, **kwargs)
|
594 | 582 | return outputs
|
| 583 | + |
| 584 | + |
| 585 | +def enable_float64(): |
| 586 | + """Returns context manager forcing operations on cpu |
| 587 | +
|
| 588 | + MLX requires operations involving `float64` to be on cpu, |
| 589 | + mimicking jax's `enable_x64()` |
| 590 | +
|
| 591 | + Usage: |
| 592 | + ``` |
| 593 | + a = mx.array([1, 2, 3], dtype=mx.float64) |
| 594 | + b = mx.array([4, 5, 6], dtype=mx.float64) |
| 595 | +
|
| 596 | + with enable_float64(): |
| 597 | + c = mx.add(a, b) |
| 598 | +
|
| 599 | + # OR |
| 600 | + mlx_cpu_context = mx.stream(mx.cpu) |
| 601 | + mlx_cpu_context.__enter__() |
| 602 | + c = mx.add(a, b) |
| 603 | + mlx_cpu_context.__exit__(None, None, None) |
| 604 | + ``` |
| 605 | + """ |
| 606 | + return mx.stream(mx.cpu) |
| 607 | + |
| 608 | + |
| 609 | +def device_scope(device_name): |
| 610 | + if isinstance(device_name, str): |
| 611 | + mlx_device = mx.cpu if "cpu" in device_name.lower() else mx.gpu |
| 612 | + elif not isinstance(device_name, mx.Device): |
| 613 | + raise ValueError( |
| 614 | + "Invalid value for argument `device_name`. " |
| 615 | + "Expected a string like 'gpu:0' or a `mlx.core.Device` instance. " |
| 616 | + f"Received: device_name='{device_name}'" |
| 617 | + ) |
| 618 | + else: |
| 619 | + mlx_device = device_name |
| 620 | + return mx.stream(mlx_device) |
0 commit comments