Skip to content

Commit 447bdc8

Browse files
authored
feat: JLL changes for #788 (#789)
1 parent fdf21dc commit 447bdc8

File tree

2 files changed: +162 additions, -116 deletions

deps/ReactantExtra/API.cpp

Lines changed: 157 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@
7777
#include "shardy/dialect/sdy/ir/dialect.h"
7878
#include "shardy/integrations/c/attributes.h"
7979
#include "xla/pjrt/mlir_to_hlo.h"
80+
#include "xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h"
8081

8182
// IFRT
8283
#include "xla/python/ifrt/array.h"
84+
#include "xla/python/ifrt/attribute_map.h"
8385
#include "xla/python/ifrt/basic_device_list.h"
8486
#include "xla/python/ifrt/client.h"
8587
#include "xla/python/ifrt/compiler.h"
@@ -98,7 +100,6 @@
98100
#include "xla/python/ifrt/topology.h"
99101
#include "xla/python/ifrt/tuple.h"
100102
#include "xla/python/ifrt/value.h"
101-
#include "xla/python/ifrt/attribute_map.h"
102103

103104
// IFRT - PJRT
104105
#include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -179,10 +180,10 @@ extern "C" void (*ReactantThrowError)(const char *) = nullptr;
179180

180181
// Utilities for `StatusOr`.
181182
template <typename T> T MyValueOrThrow(absl::StatusOr<T> v) {
182-
if (!v.ok()) {
183-
ReactantThrowError(v.status().ToString().c_str());
184-
}
185-
return std::move(v).value();
183+
if (!v.ok()) {
184+
ReactantThrowError(v.status().ToString().c_str());
185+
}
186+
return std::move(v).value();
186187
}
187188

188189
extern "C" void ReactantHandleCuResult(uint32_t curesult) {
@@ -717,109 +718,6 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
717718
return wrap(res);
718719
}
719720

720-
// Sharding
721-
struct JLOpSharding {
722-
int32_t type;
723-
int32_t n_tile_dimensions;
724-
int64_t *tile_dimensions;
725-
int32_t n_layout_minor_to_major;
726-
int64_t *layout_minor_to_major;
727-
bool replicate_on_last_tile_dim;
728-
int32_t n_last_tile_dims;
729-
int32_t *last_tile_dims;
730-
int32_t n_tile_assignment_dimensions;
731-
int64_t *tile_assignment_dimensions;
732-
int32_t n_tile_assignment_devices;
733-
int64_t *tile_assignment_devices;
734-
int32_t n_iota_reshape_dims;
735-
int64_t *iota_reshape_dims;
736-
int32_t n_iota_transpose_perm;
737-
int32_t *iota_transpose_perm;
738-
bool is_shard_group;
739-
int64_t shard_group_id;
740-
int32_t shard_group_type;
741-
const void *op_sharding;
742-
};
743-
744-
void OpShardingToJLOpSharding(const xla::OpSharding op_sharding,
745-
JLOpSharding *jl_op_sharding) {
746-
jl_op_sharding->type = op_sharding.type();
747-
jl_op_sharding->replicate_on_last_tile_dim =
748-
op_sharding.replicate_on_last_tile_dim();
749-
750-
auto &shape = op_sharding.tile_shape();
751-
jl_op_sharding->n_tile_dimensions = shape.dimensions_size();
752-
std::vector<int64_t> dimensions(shape.dimensions().begin(),
753-
shape.dimensions().end());
754-
jl_op_sharding->tile_dimensions = new int64_t[dimensions.size()];
755-
std::copy(dimensions.begin(), dimensions.end(),
756-
jl_op_sharding->tile_dimensions);
757-
758-
if (shape.has_layout()) {
759-
auto &layout = shape.layout();
760-
jl_op_sharding->n_layout_minor_to_major = layout.minor_to_major_size();
761-
std::vector<int64_t> minor_to_major(layout.minor_to_major().begin(),
762-
layout.minor_to_major().end());
763-
jl_op_sharding->layout_minor_to_major = new int64_t[minor_to_major.size()];
764-
std::copy(minor_to_major.begin(), minor_to_major.end(),
765-
jl_op_sharding->layout_minor_to_major);
766-
} else {
767-
jl_op_sharding->n_layout_minor_to_major = 0;
768-
jl_op_sharding->layout_minor_to_major = nullptr;
769-
}
770-
771-
jl_op_sharding->n_last_tile_dims = op_sharding.last_tile_dims_size();
772-
std::vector<int> last_tile_dims(op_sharding.last_tile_dims().begin(),
773-
op_sharding.last_tile_dims().end());
774-
jl_op_sharding->last_tile_dims = new int[last_tile_dims.size()];
775-
std::copy(last_tile_dims.begin(), last_tile_dims.end(),
776-
jl_op_sharding->last_tile_dims);
777-
778-
jl_op_sharding->n_tile_assignment_dimensions =
779-
op_sharding.tile_assignment_dimensions_size();
780-
std::vector<int64_t> tile_assignment_dimensions(
781-
op_sharding.tile_assignment_dimensions().begin(),
782-
op_sharding.tile_assignment_dimensions().end());
783-
jl_op_sharding->tile_assignment_dimensions =
784-
new int64_t[tile_assignment_dimensions.size()];
785-
std::copy(tile_assignment_dimensions.begin(),
786-
tile_assignment_dimensions.end(),
787-
jl_op_sharding->tile_assignment_dimensions);
788-
789-
jl_op_sharding->n_tile_assignment_devices =
790-
op_sharding.tile_assignment_devices_size();
791-
std::vector<int64_t> tile_assignment_devices(
792-
op_sharding.tile_assignment_devices().begin(),
793-
op_sharding.tile_assignment_devices().end());
794-
jl_op_sharding->tile_assignment_devices =
795-
new int64_t[tile_assignment_devices.size()];
796-
std::copy(tile_assignment_devices.begin(), tile_assignment_devices.end(),
797-
jl_op_sharding->tile_assignment_devices);
798-
799-
jl_op_sharding->n_iota_reshape_dims = op_sharding.iota_reshape_dims_size();
800-
std::vector<int64_t> iota_reshape_dims(
801-
op_sharding.iota_reshape_dims().begin(),
802-
op_sharding.iota_reshape_dims().end());
803-
jl_op_sharding->iota_reshape_dims = new int64_t[iota_reshape_dims.size()];
804-
std::copy(iota_reshape_dims.begin(), iota_reshape_dims.end(),
805-
jl_op_sharding->iota_reshape_dims);
806-
807-
jl_op_sharding->n_iota_transpose_perm =
808-
op_sharding.iota_transpose_perm_size();
809-
std::vector<int> iota_transpose_perm(
810-
op_sharding.iota_transpose_perm().begin(),
811-
op_sharding.iota_transpose_perm().end());
812-
jl_op_sharding->iota_transpose_perm = new int[iota_transpose_perm.size()];
813-
std::copy(iota_transpose_perm.begin(), iota_transpose_perm.end(),
814-
jl_op_sharding->iota_transpose_perm);
815-
816-
jl_op_sharding->is_shard_group = op_sharding.is_shard_group();
817-
jl_op_sharding->shard_group_id = op_sharding.shard_group_id();
818-
jl_op_sharding->shard_group_type = op_sharding.shard_group_type();
819-
820-
jl_op_sharding->op_sharding = new xla::OpSharding(std::move(op_sharding));
821-
}
822-
823721
typedef PjRtFuture<> FutureType;
824722
extern "C" void FreeFuture(FutureType *Future) { delete Future; }
825723

@@ -903,7 +801,7 @@ ClientCompile(PjRtClient *client, MlirModule cmod, int64_t device_id,
903801

904802
extern "C" void
905803
PjRtLoadedExecutableGetOuputShardings(xla::PjRtLoadedExecutable *exec,
906-
JLOpSharding **jl_op_shardings,
804+
xla::OpSharding **op_shardings,
907805
int32_t num_op_shardings) {
908806
std::optional<std::vector<OpSharding>> shardings = exec->GetOutputShardings();
909807
if (!shardings.has_value()) {
@@ -920,13 +818,13 @@ PjRtLoadedExecutableGetOuputShardings(xla::PjRtLoadedExecutable *exec,
920818
}
921819

922820
for (int32_t i = 0; i < num_op_shardings; i++) {
923-
OpShardingToJLOpSharding(hlo_op_shardings[i], jl_op_shardings[i]);
821+
op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]);
924822
}
925823
}
926824

927825
extern "C" void
928826
PjRtLoadedExecutableGetParameterShardings(xla::PjRtLoadedExecutable *exec,
929-
JLOpSharding **jl_op_shardings,
827+
xla::OpSharding **op_shardings,
930828
int32_t num_op_shardings) {
931829
std::optional<std::vector<OpSharding>> shardings =
932830
exec->GetParameterShardings();
@@ -944,7 +842,7 @@ PjRtLoadedExecutableGetParameterShardings(xla::PjRtLoadedExecutable *exec,
944842
}
945843

946844
for (int32_t i = 0; i < num_op_shardings; i++) {
947-
OpShardingToJLOpSharding(hlo_op_shardings[i], jl_op_shardings[i]);
845+
op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]);
948846
}
949847
}
950848

@@ -1595,7 +1493,8 @@ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
15951493
return MyValueOrThrow(
15961494
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
15971495
address,
1598-
[](xla::ifrt::AttributeMap initialization_data) -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1496+
[](xla::ifrt::AttributeMap initialization_data)
1497+
-> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
15991498
auto pjrt_client =
16001499
std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
16011500
return xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
@@ -1789,12 +1688,139 @@ extern "C" bool ifrt_MemoryKindsAreEqual(ifrt::MemoryKind *a,
17891688

17901689
#pragma endregion
17911690

1792-
#pragma region HloSharding
1691+
#pragma region OpSharding
17931692

17941693
extern "C" void free_op_sharding(xla::OpSharding *op_sharding) {
17951694
delete op_sharding;
17961695
}
17971696

1697+
extern "C" int32_t
1698+
op_sharding_to_op_sharding_type(xla::OpSharding *op_sharding) {
1699+
return static_cast<int32_t>(op_sharding->type());
1700+
}
1701+
1702+
extern "C" int32_t
1703+
op_sharding_to_shard_group_type(xla::OpSharding *op_sharding) {
1704+
return static_cast<int32_t>(op_sharding->shard_group_type());
1705+
}
1706+
1707+
extern "C" int32_t op_sharding_to_shard_group_id(xla::OpSharding *op_sharding) {
1708+
return static_cast<int32_t>(op_sharding->shard_group_id());
1709+
}
1710+
1711+
extern "C" bool op_sharding_is_shard_group(xla::OpSharding *op_sharding) {
1712+
return op_sharding->is_shard_group();
1713+
}
1714+
1715+
extern "C" bool
1716+
op_sharding_replicate_on_last_tile_dim(xla::OpSharding *op_sharding) {
1717+
return op_sharding->replicate_on_last_tile_dim();
1718+
}
1719+
1720+
extern "C" bool op_sharding_has_last_tile_dims(xla::OpSharding *op_sharding) {
1721+
return op_sharding->last_tile_dims_size() > 0;
1722+
}
1723+
1724+
extern "C" int32_t
1725+
op_sharding_last_tile_dims_size(xla::OpSharding *op_sharding) {
1726+
return static_cast<int32_t>(op_sharding->last_tile_dims_size());
1727+
}
1728+
1729+
extern "C" void op_sharding_last_tile_dims(xla::OpSharding *op_sharding,
1730+
int32_t *last_tile_dims) {
1731+
std::vector<int32_t> last_tile_dims_vec(op_sharding->last_tile_dims().begin(),
1732+
op_sharding->last_tile_dims().end());
1733+
std::copy(last_tile_dims_vec.begin(), last_tile_dims_vec.end(),
1734+
last_tile_dims);
1735+
return;
1736+
}
1737+
1738+
extern "C" bool
1739+
op_sharding_has_iota_reshape_dims(xla::OpSharding *op_sharding) {
1740+
return op_sharding->iota_reshape_dims_size() > 0;
1741+
}
1742+
1743+
extern "C" int32_t
1744+
op_sharding_iota_reshape_dims_size(xla::OpSharding *op_sharding) {
1745+
return static_cast<int32_t>(op_sharding->iota_reshape_dims_size());
1746+
}
1747+
1748+
extern "C" void op_sharding_iota_reshape_dims(xla::OpSharding *op_sharding,
1749+
int32_t *iota_reshape_dims) {
1750+
std::vector<int32_t> iota_reshape_dims_vec(
1751+
op_sharding->iota_reshape_dims().begin(),
1752+
op_sharding->iota_reshape_dims().end());
1753+
std::copy(iota_reshape_dims_vec.begin(), iota_reshape_dims_vec.end(),
1754+
iota_reshape_dims);
1755+
return;
1756+
}
1757+
1758+
extern "C" bool
1759+
op_sharding_has_iota_transpose_perm(xla::OpSharding *op_sharding) {
1760+
return op_sharding->iota_transpose_perm_size() > 0;
1761+
}
1762+
1763+
extern "C" int32_t
1764+
op_sharding_iota_transpose_perm_size(xla::OpSharding *op_sharding) {
1765+
return static_cast<int32_t>(op_sharding->iota_transpose_perm_size());
1766+
}
1767+
1768+
extern "C" void op_sharding_iota_transpose_perm(xla::OpSharding *op_sharding,
1769+
int32_t *iota_transpose_perm) {
1770+
std::vector<int32_t> iota_transpose_perm_vec(
1771+
op_sharding->iota_transpose_perm().begin(),
1772+
op_sharding->iota_transpose_perm().end());
1773+
std::copy(iota_transpose_perm_vec.begin(), iota_transpose_perm_vec.end(),
1774+
iota_transpose_perm);
1775+
return;
1776+
}
1777+
1778+
extern "C" bool
1779+
op_sharding_has_tile_assignment_dimensions(xla::OpSharding *op_sharding) {
1780+
return op_sharding->tile_assignment_dimensions_size() > 0;
1781+
}
1782+
1783+
extern "C" int32_t
1784+
op_sharding_tile_assignment_dimensions_size(xla::OpSharding *op_sharding) {
1785+
return static_cast<int32_t>(op_sharding->tile_assignment_dimensions_size());
1786+
}
1787+
1788+
extern "C" void
1789+
op_sharding_tile_assignment_dimensions(xla::OpSharding *op_sharding,
1790+
int32_t *tile_assignment_dimensions) {
1791+
std::vector<int32_t> tile_assignment_dimensions_vec(
1792+
op_sharding->tile_assignment_dimensions().begin(),
1793+
op_sharding->tile_assignment_dimensions().end());
1794+
std::copy(tile_assignment_dimensions_vec.begin(),
1795+
tile_assignment_dimensions_vec.end(), tile_assignment_dimensions);
1796+
return;
1797+
}
1798+
1799+
extern "C" bool
1800+
op_sharding_has_tile_assignment_devices(xla::OpSharding *op_sharding) {
1801+
return op_sharding->tile_assignment_devices_size() > 0;
1802+
}
1803+
1804+
extern "C" int32_t
1805+
op_sharding_tile_assignment_devices_size(xla::OpSharding *op_sharding) {
1806+
return static_cast<int32_t>(op_sharding->tile_assignment_devices_size());
1807+
}
1808+
1809+
extern "C" void
1810+
op_sharding_tile_assignment_devices(xla::OpSharding *op_sharding,
1811+
int32_t *tile_assignment_devices) {
1812+
std::vector<int32_t> tile_assignment_devices_vec(
1813+
op_sharding->tile_assignment_devices().begin(),
1814+
op_sharding->tile_assignment_devices().end());
1815+
std::copy(tile_assignment_devices_vec.begin(),
1816+
tile_assignment_devices_vec.end(), tile_assignment_devices);
1817+
return;
1818+
}
1819+
1820+
#pragma endregion
1821+
1822+
#pragma region HloSharding
1823+
17981824
extern "C" void free_hlo_sharding(xla::HloSharding *hlo_sharding) {
17991825
delete hlo_sharding;
18001826
}
@@ -1914,3 +1940,20 @@ extern "C" void distributed_runtime_service_shutdown(
19141940
}
19151941

19161942
#pragma endregion
1943+
1944+
#pragma region Shardy
1945+
1946+
extern "C" xla::HloSharding *
1947+
hloShardingFromTensorShardingAttr(mlir::sdy::TensorShardingAttr attr,
1948+
mlir::sdy::MeshAttr meshAttr) {
1949+
mlir::ArrayRef<mlir::StringAttr> manual_axes = {};
1950+
std::function<mlir::sdy::MeshAttr(mlir::sdy::TensorShardingAttr)>
1951+
get_mesh_attr = [meshAttr](mlir::sdy::TensorShardingAttr local_attr) {
1952+
return meshAttr;
1953+
};
1954+
1955+
return new xla::HloSharding(
1956+
xla::sdy::convertToHloSharding(attr, get_mesh_attr, manual_axes));
1957+
}
1958+
1959+
#pragma endregion

deps/ReactantExtra/BUILD

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,8 @@ cc_library(
490490
"-Wl,-exported_symbol,_distributed_runtime_service_shutdown",
491491
"-Wl,-exported_symbol,_ClientGetDevices",
492492
"-Wl,-exported_symbol,_ClientGetAddressableDevices",
493+
"-Wl,-exported_symbol,_hloShardingFromTensorShardingAttr",
494+
"-Wl,-exported_symbol,_op_sharding_*",
493495
]}),
494496
deps = [
495497
"@enzyme//:EnzymeMLIR",
@@ -538,12 +540,13 @@ cc_library(
538540
"@xla//xla/pjrt/distributed:distributed",
539541
"@xla//xla/pjrt/distributed:client",
540542
"@xla//xla/pjrt/distributed:service",
543+
"@xla//xla/service/spmd/shardy/stablehlo_round_trip:export_shardings",
541544

542545
"@xla//xla:xla_proto_cc",
543546
"@xla//xla:xla_proto_cc_impl",
544547
"@xla//xla/stream_executor:device_description_proto_cc_impl",
545548

546-
"@xla//xla/tsl/platform/default:platform_port",
549+
"@xla//xla/tsl/platform/default:platform_port",
547550

548551
"@xla//xla/service:metrics_proto_cc",
549552
"@xla//xla/service:metrics_proto_cc_impl",
@@ -557,7 +560,7 @@ cc_library(
557560
"@xla//xla/service/cpu:cpu_transfer_manager",
558561
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
559562

560-
"@xla//xla/tsl/protobuf:protos_all_cc_impl",
563+
"@xla//xla/tsl/protobuf:protos_all_cc_impl",
561564
"@xla//xla/tsl/framework:allocator_registry_impl",
562565

563566
"@xla//xla/pjrt:status_casters",

Comments (0)