7777#include " shardy/dialect/sdy/ir/dialect.h"
7878#include " shardy/integrations/c/attributes.h"
7979#include " xla/pjrt/mlir_to_hlo.h"
80+ #include " xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h"
8081
8182// IFRT
8283#include " xla/python/ifrt/array.h"
84+ #include " xla/python/ifrt/attribute_map.h"
8385#include " xla/python/ifrt/basic_device_list.h"
8486#include " xla/python/ifrt/client.h"
8587#include " xla/python/ifrt/compiler.h"
98100#include " xla/python/ifrt/topology.h"
99101#include " xla/python/ifrt/tuple.h"
100102#include " xla/python/ifrt/value.h"
101- #include " xla/python/ifrt/attribute_map.h"
102103
103104// IFRT - PJRT
104105#include " xla/python/pjrt_ifrt/pjrt_array.h"
@@ -179,10 +180,10 @@ extern "C" void (*ReactantThrowError)(const char *) = nullptr;
179180
180181// Utilities for `StatusOr`.
181182template <typename T> T MyValueOrThrow (absl::StatusOr<T> v) {
182- if (!v.ok ()) {
183- ReactantThrowError (v.status ().ToString ().c_str ());
184- }
185- return std::move (v).value ();
183+ if (!v.ok ()) {
184+ ReactantThrowError (v.status ().ToString ().c_str ());
185+ }
186+ return std::move (v).value ();
186187}
187188
188189extern " C" void ReactantHandleCuResult (uint32_t curesult) {
@@ -717,109 +718,6 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
717718 return wrap (res);
718719}
719720
720- // Sharding
721- struct JLOpSharding {
722- int32_t type;
723- int32_t n_tile_dimensions;
724- int64_t *tile_dimensions;
725- int32_t n_layout_minor_to_major;
726- int64_t *layout_minor_to_major;
727- bool replicate_on_last_tile_dim;
728- int32_t n_last_tile_dims;
729- int32_t *last_tile_dims;
730- int32_t n_tile_assignment_dimensions;
731- int64_t *tile_assignment_dimensions;
732- int32_t n_tile_assignment_devices;
733- int64_t *tile_assignment_devices;
734- int32_t n_iota_reshape_dims;
735- int64_t *iota_reshape_dims;
736- int32_t n_iota_transpose_perm;
737- int32_t *iota_transpose_perm;
738- bool is_shard_group;
739- int64_t shard_group_id;
740- int32_t shard_group_type;
741- const void *op_sharding;
742- };
743-
744- void OpShardingToJLOpSharding (const xla::OpSharding op_sharding,
745- JLOpSharding *jl_op_sharding) {
746- jl_op_sharding->type = op_sharding.type ();
747- jl_op_sharding->replicate_on_last_tile_dim =
748- op_sharding.replicate_on_last_tile_dim ();
749-
750- auto &shape = op_sharding.tile_shape ();
751- jl_op_sharding->n_tile_dimensions = shape.dimensions_size ();
752- std::vector<int64_t > dimensions (shape.dimensions ().begin (),
753- shape.dimensions ().end ());
754- jl_op_sharding->tile_dimensions = new int64_t [dimensions.size ()];
755- std::copy (dimensions.begin (), dimensions.end (),
756- jl_op_sharding->tile_dimensions );
757-
758- if (shape.has_layout ()) {
759- auto &layout = shape.layout ();
760- jl_op_sharding->n_layout_minor_to_major = layout.minor_to_major_size ();
761- std::vector<int64_t > minor_to_major (layout.minor_to_major ().begin (),
762- layout.minor_to_major ().end ());
763- jl_op_sharding->layout_minor_to_major = new int64_t [minor_to_major.size ()];
764- std::copy (minor_to_major.begin (), minor_to_major.end (),
765- jl_op_sharding->layout_minor_to_major );
766- } else {
767- jl_op_sharding->n_layout_minor_to_major = 0 ;
768- jl_op_sharding->layout_minor_to_major = nullptr ;
769- }
770-
771- jl_op_sharding->n_last_tile_dims = op_sharding.last_tile_dims_size ();
772- std::vector<int > last_tile_dims (op_sharding.last_tile_dims ().begin (),
773- op_sharding.last_tile_dims ().end ());
774- jl_op_sharding->last_tile_dims = new int [last_tile_dims.size ()];
775- std::copy (last_tile_dims.begin (), last_tile_dims.end (),
776- jl_op_sharding->last_tile_dims );
777-
778- jl_op_sharding->n_tile_assignment_dimensions =
779- op_sharding.tile_assignment_dimensions_size ();
780- std::vector<int64_t > tile_assignment_dimensions (
781- op_sharding.tile_assignment_dimensions ().begin (),
782- op_sharding.tile_assignment_dimensions ().end ());
783- jl_op_sharding->tile_assignment_dimensions =
784- new int64_t [tile_assignment_dimensions.size ()];
785- std::copy (tile_assignment_dimensions.begin (),
786- tile_assignment_dimensions.end (),
787- jl_op_sharding->tile_assignment_dimensions );
788-
789- jl_op_sharding->n_tile_assignment_devices =
790- op_sharding.tile_assignment_devices_size ();
791- std::vector<int64_t > tile_assignment_devices (
792- op_sharding.tile_assignment_devices ().begin (),
793- op_sharding.tile_assignment_devices ().end ());
794- jl_op_sharding->tile_assignment_devices =
795- new int64_t [tile_assignment_devices.size ()];
796- std::copy (tile_assignment_devices.begin (), tile_assignment_devices.end (),
797- jl_op_sharding->tile_assignment_devices );
798-
799- jl_op_sharding->n_iota_reshape_dims = op_sharding.iota_reshape_dims_size ();
800- std::vector<int64_t > iota_reshape_dims (
801- op_sharding.iota_reshape_dims ().begin (),
802- op_sharding.iota_reshape_dims ().end ());
803- jl_op_sharding->iota_reshape_dims = new int64_t [iota_reshape_dims.size ()];
804- std::copy (iota_reshape_dims.begin (), iota_reshape_dims.end (),
805- jl_op_sharding->iota_reshape_dims );
806-
807- jl_op_sharding->n_iota_transpose_perm =
808- op_sharding.iota_transpose_perm_size ();
809- std::vector<int > iota_transpose_perm (
810- op_sharding.iota_transpose_perm ().begin (),
811- op_sharding.iota_transpose_perm ().end ());
812- jl_op_sharding->iota_transpose_perm = new int [iota_transpose_perm.size ()];
813- std::copy (iota_transpose_perm.begin (), iota_transpose_perm.end (),
814- jl_op_sharding->iota_transpose_perm );
815-
816- jl_op_sharding->is_shard_group = op_sharding.is_shard_group ();
817- jl_op_sharding->shard_group_id = op_sharding.shard_group_id ();
818- jl_op_sharding->shard_group_type = op_sharding.shard_group_type ();
819-
820- jl_op_sharding->op_sharding = new xla::OpSharding (std::move (op_sharding));
821- }
822-
823721typedef PjRtFuture<> FutureType;
824722extern " C" void FreeFuture (FutureType *Future) { delete Future; }
825723
@@ -903,7 +801,7 @@ ClientCompile(PjRtClient *client, MlirModule cmod, int64_t device_id,
903801
904802extern " C" void
905803PjRtLoadedExecutableGetOuputShardings (xla::PjRtLoadedExecutable *exec,
906- JLOpSharding **jl_op_shardings ,
804+ xla::OpSharding **op_shardings ,
907805 int32_t num_op_shardings) {
908806 std::optional<std::vector<OpSharding>> shardings = exec->GetOutputShardings ();
909807 if (!shardings.has_value ()) {
@@ -920,13 +818,13 @@ PjRtLoadedExecutableGetOuputShardings(xla::PjRtLoadedExecutable *exec,
920818 }
921819
922820 for (int32_t i = 0 ; i < num_op_shardings; i++) {
923- OpShardingToJLOpSharding (hlo_op_shardings [i], jl_op_shardings [i]);
821+ op_shardings [i] = new xla::OpSharding (hlo_op_shardings [i]);
924822 }
925823}
926824
927825extern " C" void
928826PjRtLoadedExecutableGetParameterShardings (xla::PjRtLoadedExecutable *exec,
929- JLOpSharding **jl_op_shardings ,
827+ xla::OpSharding **op_shardings ,
930828 int32_t num_op_shardings) {
931829 std::optional<std::vector<OpSharding>> shardings =
932830 exec->GetParameterShardings ();
@@ -944,7 +842,7 @@ PjRtLoadedExecutableGetParameterShardings(xla::PjRtLoadedExecutable *exec,
944842 }
945843
946844 for (int32_t i = 0 ; i < num_op_shardings; i++) {
947- OpShardingToJLOpSharding (hlo_op_shardings [i], jl_op_shardings [i]);
845+ op_shardings [i] = new xla::OpSharding (hlo_op_shardings [i]);
948846 }
949847}
950848
@@ -1595,7 +1493,8 @@ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
15951493 return MyValueOrThrow (
15961494 xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
15971495 address,
1598- [](xla::ifrt::AttributeMap initialization_data) -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1496+ [](xla::ifrt::AttributeMap initialization_data)
1497+ -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
15991498 auto pjrt_client =
16001499 std::shared_ptr<xla::PjRtClient>(GetCApiClient (" TPU" ));
16011500 return xla::ifrt::PjRtClient::Create (std::move (pjrt_client));
@@ -1789,12 +1688,139 @@ extern "C" bool ifrt_MemoryKindsAreEqual(ifrt::MemoryKind *a,
17891688
17901689#pragma endregion
17911690
1792- #pragma region HloSharding
1691+ #pragma region OpSharding
17931692
17941693extern " C" void free_op_sharding (xla::OpSharding *op_sharding) {
17951694 delete op_sharding;
17961695}
17971696
1697+ extern " C" int32_t
1698+ op_sharding_to_op_sharding_type (xla::OpSharding *op_sharding) {
1699+ return static_cast <int32_t >(op_sharding->type ());
1700+ }
1701+
1702+ extern " C" int32_t
1703+ op_sharding_to_shard_group_type (xla::OpSharding *op_sharding) {
1704+ return static_cast <int32_t >(op_sharding->shard_group_type ());
1705+ }
1706+
1707+ extern " C" int32_t op_sharding_to_shard_group_id (xla::OpSharding *op_sharding) {
1708+ return static_cast <int32_t >(op_sharding->shard_group_id ());
1709+ }
1710+
1711+ extern " C" bool op_sharding_is_shard_group (xla::OpSharding *op_sharding) {
1712+ return op_sharding->is_shard_group ();
1713+ }
1714+
1715+ extern " C" bool
1716+ op_sharding_replicate_on_last_tile_dim (xla::OpSharding *op_sharding) {
1717+ return op_sharding->replicate_on_last_tile_dim ();
1718+ }
1719+
1720+ extern " C" bool op_sharding_has_last_tile_dims (xla::OpSharding *op_sharding) {
1721+ return op_sharding->last_tile_dims_size () > 0 ;
1722+ }
1723+
1724+ extern " C" int32_t
1725+ op_sharding_last_tile_dims_size (xla::OpSharding *op_sharding) {
1726+ return static_cast <int32_t >(op_sharding->last_tile_dims_size ());
1727+ }
1728+
1729+ extern " C" void op_sharding_last_tile_dims (xla::OpSharding *op_sharding,
1730+ int32_t *last_tile_dims) {
1731+ std::vector<int32_t > last_tile_dims_vec (op_sharding->last_tile_dims ().begin (),
1732+ op_sharding->last_tile_dims ().end ());
1733+ std::copy (last_tile_dims_vec.begin (), last_tile_dims_vec.end (),
1734+ last_tile_dims);
1735+ return ;
1736+ }
1737+
1738+ extern " C" bool
1739+ op_sharding_has_iota_reshape_dims (xla::OpSharding *op_sharding) {
1740+ return op_sharding->iota_reshape_dims_size () > 0 ;
1741+ }
1742+
1743+ extern " C" int32_t
1744+ op_sharding_iota_reshape_dims_size (xla::OpSharding *op_sharding) {
1745+ return static_cast <int32_t >(op_sharding->iota_reshape_dims_size ());
1746+ }
1747+
1748+ extern " C" void op_sharding_iota_reshape_dims (xla::OpSharding *op_sharding,
1749+ int32_t *iota_reshape_dims) {
1750+ std::vector<int32_t > iota_reshape_dims_vec (
1751+ op_sharding->iota_reshape_dims ().begin (),
1752+ op_sharding->iota_reshape_dims ().end ());
1753+ std::copy (iota_reshape_dims_vec.begin (), iota_reshape_dims_vec.end (),
1754+ iota_reshape_dims);
1755+ return ;
1756+ }
1757+
1758+ extern " C" bool
1759+ op_sharding_has_iota_transpose_perm (xla::OpSharding *op_sharding) {
1760+ return op_sharding->iota_transpose_perm_size () > 0 ;
1761+ }
1762+
1763+ extern " C" int32_t
1764+ op_sharding_iota_transpose_perm_size (xla::OpSharding *op_sharding) {
1765+ return static_cast <int32_t >(op_sharding->iota_transpose_perm_size ());
1766+ }
1767+
1768+ extern " C" void op_sharding_iota_transpose_perm (xla::OpSharding *op_sharding,
1769+ int32_t *iota_transpose_perm) {
1770+ std::vector<int32_t > iota_transpose_perm_vec (
1771+ op_sharding->iota_transpose_perm ().begin (),
1772+ op_sharding->iota_transpose_perm ().end ());
1773+ std::copy (iota_transpose_perm_vec.begin (), iota_transpose_perm_vec.end (),
1774+ iota_transpose_perm);
1775+ return ;
1776+ }
1777+
1778+ extern " C" bool
1779+ op_sharding_has_tile_assignment_dimensions (xla::OpSharding *op_sharding) {
1780+ return op_sharding->tile_assignment_dimensions_size () > 0 ;
1781+ }
1782+
1783+ extern " C" int32_t
1784+ op_sharding_tile_assignment_dimensions_size (xla::OpSharding *op_sharding) {
1785+ return static_cast <int32_t >(op_sharding->tile_assignment_dimensions_size ());
1786+ }
1787+
1788+ extern " C" void
1789+ op_sharding_tile_assignment_dimensions (xla::OpSharding *op_sharding,
1790+ int32_t *tile_assignment_dimensions) {
1791+ std::vector<int32_t > tile_assignment_dimensions_vec (
1792+ op_sharding->tile_assignment_dimensions ().begin (),
1793+ op_sharding->tile_assignment_dimensions ().end ());
1794+ std::copy (tile_assignment_dimensions_vec.begin (),
1795+ tile_assignment_dimensions_vec.end (), tile_assignment_dimensions);
1796+ return ;
1797+ }
1798+
1799+ extern " C" bool
1800+ op_sharding_has_tile_assignment_devices (xla::OpSharding *op_sharding) {
1801+ return op_sharding->tile_assignment_devices_size () > 0 ;
1802+ }
1803+
1804+ extern " C" int32_t
1805+ op_sharding_tile_assignment_devices_size (xla::OpSharding *op_sharding) {
1806+ return static_cast <int32_t >(op_sharding->tile_assignment_devices_size ());
1807+ }
1808+
1809+ extern " C" void
1810+ op_sharding_tile_assignment_devices (xla::OpSharding *op_sharding,
1811+ int32_t *tile_assignment_devices) {
1812+ std::vector<int32_t > tile_assignment_devices_vec (
1813+ op_sharding->tile_assignment_devices ().begin (),
1814+ op_sharding->tile_assignment_devices ().end ());
1815+ std::copy (tile_assignment_devices_vec.begin (),
1816+ tile_assignment_devices_vec.end (), tile_assignment_devices);
1817+ return ;
1818+ }
1819+
1820+ #pragma endregion
1821+
1822+ #pragma region HloSharding
1823+
17981824extern " C" void free_hlo_sharding (xla::HloSharding *hlo_sharding) {
17991825 delete hlo_sharding;
18001826}
@@ -1914,3 +1940,20 @@ extern "C" void distributed_runtime_service_shutdown(
19141940}
19151941
19161942#pragma endregion
1943+
1944+ #pragma region Shardy
1945+
1946+ extern " C" xla::HloSharding *
1947+ hloShardingFromTensorShardingAttr (mlir::sdy::TensorShardingAttr attr,
1948+ mlir::sdy::MeshAttr meshAttr) {
1949+ mlir::ArrayRef<mlir::StringAttr> manual_axes = {};
1950+ std::function<mlir::sdy::MeshAttr (mlir::sdy::TensorShardingAttr)>
1951+ get_mesh_attr = [meshAttr](mlir::sdy::TensorShardingAttr local_attr) {
1952+ return meshAttr;
1953+ };
1954+
1955+ return new xla::HloSharding (
1956+ xla::sdy::convertToHloSharding (attr, get_mesh_attr, manual_axes));
1957+ }
1958+
1959+ #pragma endregion
0 commit comments