|
29 | 29 | "jax_visibility", |
30 | 30 | "mosaic_gpu_internal_users", |
31 | 31 | "mosaic_internal_users", |
| 32 | + "pallas_fuser_users", |
32 | 33 | "pallas_gpu_internal_users", |
33 | 34 | "pallas_tpu_internal_users", |
34 | 35 | "py_deps", |
@@ -105,6 +106,12 @@ package_group( |
105 | 106 | packages = pallas_tpu_internal_users, |
106 | 107 | ) |
107 | 108 |
|
| 109 | +package_group( |
| 110 | + name = "pallas_fuser_users", |
| 111 | + includes = [":internal"], |
| 112 | + packages = pallas_fuser_users, |
| 113 | +) |
| 114 | + |
108 | 115 | package_group( |
109 | 116 | name = "mosaic_gpu_users", |
110 | 117 | includes = [":internal"], |
@@ -628,6 +635,7 @@ pytype_strict_library( |
628 | 635 | "experimental/pallas/ops/gpu/**/*.py", |
629 | 636 | "experimental/pallas/ops/tpu/**/*.py", |
630 | 637 | "experimental/pallas/tpu.py", |
| 638 | + "experimental/pallas/fuser.py", |
631 | 639 | "experimental/pallas/triton.py", |
632 | 640 | ], |
633 | 641 | ), |
@@ -664,6 +672,21 @@ pytype_strict_library( |
664 | 672 | ], |
665 | 673 | ) |
666 | 674 |
|
| 675 | +pytype_strict_library( |
| 676 | + name = "pallas_fuser", |
| 677 | + srcs = ["experimental/pallas/fuser.py"], |
| 678 | + visibility = [ |
| 679 | + ":pallas_fuser_users", |
| 680 | + ], |
| 681 | + deps = [ |
| 682 | + ":pallas", # build_cleaner: keep |
| 683 | + "//jax/_src/pallas/fuser:block_spec", |
| 684 | + "//jax/_src/pallas/fuser:fusable", |
| 685 | + "//jax/_src/pallas/fuser:fusion", |
| 686 | + "//jax/_src/pallas/fuser:jaxpr_fusion", |
| 687 | + ], |
| 688 | +) |
| 689 | + |
667 | 690 | pytype_strict_library( |
668 | 691 | name = "pallas_gpu_ops", |
669 | 692 | srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"], |
|
0 commit comments