1 | 1 | from functools import partial
2 | 2 | from collections.abc import Callable
3 | 3 | from enum import auto, Enum
| 4 | +from collections.abc import Sequence |
4 | 5 | from looseversion import LooseVersion
5 | 6 |
6 | 7 | from thunder.torch import torchsymbol, TensorLike, register_function
@@ -371,6 +372,100 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
371 | 372 | )
372 | 373 |
373 | 374 |
| 375 | +if torch.distributed.is_available(): |
| 376 | +    from torch.distributed.tensor import DTensor, DeviceMesh |
| 377 | +    from torch.distributed.tensor.placement_types import Placement |
| 378 | + |
| 379 | + def dtensor_from_local_meta( |
| 380 | + x, |
| 381 | + mesh, |
| 382 | + placements, |
| 383 | + *, |
| 384 | + run_check: bool = False, |
| 385 | + shape: torch.Size | None = None, |
| 386 | + stride: tuple[int, ...] | None = None, |
| 387 | +    ) -> DTensorProxy | None: |
| 388 | + res = run_with_fake_tensor( |
| 389 | + DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride |
| 390 | + ) |
| 391 | + from thunder.torch.experimental.dtensor_proxy import proxify_dtensor |
| 392 | + |
| 393 | + res = proxify_dtensor(res) |
| 394 | + return res |
| 395 | + |
| 396 | + dtensor_from_local_prim = make_prim("dtensor_from_local", "dtensor_from_local", meta=dtensor_from_local_meta) |
| 397 | + |
| 398 | + dtensor_from_local_prim_impl = pytorchex.register_operator( |
| 399 | + "dtensor_from_local", like=dtensor_from_local_prim, fn=DTensor.from_local |
| 400 | + ) |
| 401 | + |
| 402 | + pytorchex.register_implementation(dtensor_from_local_prim, dtensor_from_local_prim_impl) |
| 403 | + |
| 404 | + @dtensor_torchsymbol(DTensor.from_local, id="dtensor.torch.from_local") |
| 405 | + def dtensor_from_local( |
| 406 | + x, |
| 407 | + mesh, |
| 408 | + placements, |
| 409 | + *, |
| 410 | + run_check: bool = False, |
| 411 | + shape: torch.Size | None = None, |
| 412 | + stride: tuple[int, ...] | None = None, |
| 413 | + ) -> DTensorProxy | None: |
| 414 | + return dtensor_from_local_prim(x, mesh, placements, run_check=run_check, shape=shape, stride=stride) |
| 415 | + |
| 416 | + def dtensor_redistribute_meta( |
| 417 | + dtensor, |
| 418 | + device_mesh: DeviceMesh | None = None, |
| 419 | + placements: Sequence[Placement] | None = None, |
| 420 | + *, |
| 421 | + async_op: bool = False, |
| 422 | + ) -> DTensorProxy | None: |
| 423 | + res = run_with_fake_tensor(DTensor.redistribute, dtensor, device_mesh, placements, async_op=async_op) |
| 424 | + from thunder.torch.experimental.dtensor_proxy import proxify_dtensor |
| 425 | + |
| 426 | + res = proxify_dtensor(res) |
| 427 | + return res |
| 428 | + |
| 429 | + dtensor_redistribute_prim = make_prim( |
| 430 | + "dtensor_redistribute", "dtensor_redistribute", meta=dtensor_redistribute_meta |
| 431 | + ) |
| 432 | + |
| 433 | + dtensor_redistribute_prim_impl = pytorchex.register_operator( |
| 434 | + "dtensor_redistribute", like=dtensor_redistribute_prim, fn=DTensor.redistribute |
| 435 | + ) |
| 436 | + |
| 437 | + @dtensor_torchsymbol(DTensor.redistribute, id="dtensor.torch.redistribute") |
| 438 | + def dtensor_redistribute( |
| 439 | + dtensor, |
| 440 | + device_mesh: DeviceMesh | None = None, |
| 441 | + placements: Sequence[Placement] | None = None, |
| 442 | + *, |
| 443 | + async_op: bool = False, |
| 444 | + ) -> DTensorProxy | None: |
| 445 | + return dtensor_redistribute_prim(dtensor, device_mesh, placements, async_op=async_op) |
| 446 | + |
| 447 | + pytorchex.register_implementation(dtensor_redistribute_prim, dtensor_redistribute_prim_impl) |
| 448 | + |
| 449 | + def dtensor_to_local_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None): |
| 450 | + res = run_with_fake_tensor(DTensor.to_local, dtensor, grad_placements=grad_placements) |
| 451 | + from thunder.core.proxies import proxy |
| 452 | + |
| 453 | + res = proxy(res) |
| 454 | + return res |
| 455 | + |
| 456 | + dtensor_to_local_prim = make_prim("dtensor_to_local", "dtensor_to_local", meta=dtensor_to_local_meta) |
| 457 | + |
| 458 | + dtensor_to_local_prim_impl = pytorchex.register_operator( |
| 459 | + "dtensor_to_local", like=dtensor_to_local_prim, fn=DTensor.to_local |
| 460 | + ) |
| 461 | + |
| 462 | + pytorchex.register_implementation(dtensor_to_local_prim, dtensor_to_local_prim_impl) |
| 463 | + |
| 464 | + @dtensor_torchsymbol(DTensor.to_local, id="dtensor.torch.to_local") |
| 465 | + def dtensor_to_local(dtensor, *, grad_placements: Sequence[Placement] | None = None) -> DTensorProxy | None: |
| 466 | + return dtensor_to_local_prim(dtensor, grad_placements=grad_placements) |
| 467 | + |
| 468 | + |
374 | 469 | expand = partial(expand_impl, broadcast_prim=dtensor_broadcast_in_dim_prim)
375 | 470 | maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=expand)
376 | 471 |
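
For context, the three prims added above execute the eager `DTensor.from_local`, `DTensor.redistribute`, and `DTensor.to_local` methods at runtime. Below is a minimal sketch of those underlying calls, assuming a single-host "gloo" process group and a 1-D CPU device mesh; the process-group and mesh setup is illustrative boilerplate, not part of this diff:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard

# Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set (e.g. via torchrun)
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

local = torch.randn(4, 8)
dt = DTensor.from_local(local, mesh, [Shard(0)], run_check=False)  # -> dtensor_from_local_prim
dt = dt.redistribute(mesh, [Replicate()], async_op=False)          # -> dtensor_redistribute_prim
out = dt.to_local()                                                # -> dtensor_to_local_prim
```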
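The same meta/prim/impl/torchsymbol recipe extends to other `DTensor` methods. As a hypothetical illustration (not part of this diff), here is how `DTensor.full_tensor` could be registered; `make_prim`, `pytorchex`, `run_with_fake_tensor`, and `dtensor_torchsymbol` are the helpers already used in this file:

```python
# Hypothetical sketch: registering DTensor.full_tensor with the same recipe.
def dtensor_full_tensor_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None):
    res = run_with_fake_tensor(DTensor.full_tensor, dtensor, grad_placements=grad_placements)
    from thunder.core.proxies import proxy

    # full_tensor materializes a plain torch.Tensor, so, like to_local above,
    # the result is wrapped with the generic proxy rather than proxify_dtensor
    return proxy(res)

dtensor_full_tensor_prim = make_prim("dtensor_full_tensor", "dtensor_full_tensor", meta=dtensor_full_tensor_meta)

dtensor_full_tensor_prim_impl = pytorchex.register_operator(
    "dtensor_full_tensor", like=dtensor_full_tensor_prim, fn=DTensor.full_tensor
)

pytorchex.register_implementation(dtensor_full_tensor_prim, dtensor_full_tensor_prim_impl)

@dtensor_torchsymbol(DTensor.full_tensor, id="dtensor.torch.full_tensor")
def dtensor_full_tensor(dtensor, *, grad_placements: Sequence[Placement] | None = None):
    return dtensor_full_tensor_prim(dtensor, grad_placements=grad_placements)
```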