Skip to content

Commit 9338ffd

Browse files
KanishAnandGoogle-ML-Automation
authored andcommitted
Update mesh definition to better match it's use cases of querying tile index from device id's or vice-versa. Refactor into separate classes.
#hloshardingv3 PiperOrigin-RevId: 820154911
1 parent 1cd9f9e commit 9338ffd

File tree

4 files changed

+134
-46
lines changed

4 files changed

+134
-46
lines changed

xla/hlo/ir/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ cc_library(
6060
deps = [
6161
":backend_config",
6262
":hlo_sharding",
63+
":named_sharding",
6364
":ptrvec",
6465
":tile_assignment",
6566
"//xla:array",
@@ -177,12 +178,31 @@ xla_cc_test(
177178
],
178179
)
179180

181+
cc_library(
182+
name = "mesh_and_axis",
183+
hdrs = ["mesh_and_axis.h"],
184+
deps = [
185+
":tile_assignment",
186+
"@com_google_absl//absl/strings",
187+
],
188+
)
189+
190+
cc_library(
191+
name = "named_sharding",
192+
hdrs = ["named_sharding.h"],
193+
deps = [
194+
":mesh_and_axis",
195+
"//xla:xla_data_proto_cc",
196+
],
197+
)
198+
180199
cc_library(
181200
name = "hlo_sharding",
182201
srcs = ["hlo_sharding.cc"],
183202
hdrs = ["hlo_sharding.h"],
184203
deps = [
185204
":hlo_op_metadata",
205+
":named_sharding",
186206
":tile_assignment",
187207
"//xla:array",
188208
"//xla:printer",

xla/hlo/ir/hlo_sharding.h

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "absl/strings/string_view.h"
3636
#include "absl/types/span.h"
3737
#include "xla/array.h"
38+
#include "xla/hlo/ir/named_sharding.h"
3839
#include "xla/hlo/ir/tile_assignment.h" // IWYU pragma: export
3940
#include "xla/printer.h"
4041
#include "xla/shape.h"
@@ -53,52 +54,6 @@ class HloSharding {
5354
static inline constexpr absl::string_view kShardingFrontendAttrName =
5455
"xla.sdy.sharding";
5556

56-
// C++ representation for corresponding proto types in `xla_data.proto` so
57-
// same documentation applies, except AxisRef elements are pointers to
58-
// `MeshAxis` elements instead of indices.
59-
//
60-
// TODO(b/449783607): Move mesh, axis to mesh_and_axis.h and move
61-
// NamedSharding out of HloSharding to match proto after change to using
62-
// mesh_and_axis.h. Currently simply moving this out will cause name
63-
// clashes with proto as they both use same xla namespace.
64-
struct MeshAxis {
65-
std::string name;
66-
int64_t size;
67-
};
68-
69-
struct Mesh {
70-
std::vector<MeshAxis> axes;
71-
std::vector<int64_t> device_ids;
72-
};
73-
74-
struct AxisRef {
75-
struct SubAxis {
76-
int64_t pre_size;
77-
int64_t size;
78-
};
79-
80-
const MeshAxis* axis;
81-
std::optional<SubAxis> sub_axis_info;
82-
};
83-
84-
// C++ representation for corresponding `OpSharding::NamedSharding` proto.
85-
//
86-
// TODO(b/450770542): Add corresponding IFTTT in attrs.td
87-
class NamedSharding {
88-
struct DimensionSharding {
89-
std::vector<AxisRef> axes;
90-
bool is_closed;
91-
};
92-
93-
std::vector<NamedSharding> tuple_shardings_;
94-
95-
Mesh mesh_;
96-
std::vector<DimensionSharding> dim_shardings_;
97-
std::vector<AxisRef> replicated_axes_;
98-
std::vector<AxisRef> unreduced_axes_;
99-
std::vector<OpMetadata> metadata_;
100-
};
101-
10257
// Creates a trivial sharding that replicates a maximal tile across all
10358
// devices.
10459
static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {

xla/hlo/ir/mesh_and_axis.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_HLO_IR_MESH_AND_AXIS_H_
17+
#define XLA_HLO_IR_MESH_AND_AXIS_H_
18+
19+
#include <cstdint>
20+
#include <optional>
21+
#include <string>
22+
#include <vector>
23+
24+
#include "xla/hlo/ir/tile_assignment.h"
25+
26+
namespace xla {
27+
28+
// C++ representation for corresponding `OpSharding::Mesh` proto so same
29+
// documentation applies, except device assignment is represented in the array
30+
// format instead of list of device ids to align with various array specific
31+
// queries. Note that `TileAssignment` is used instead of `xla::Array` for
32+
// optimized array representation in iota based cases which is the most common
33+
// case.
34+
//
35+
// Example: device_assignment {{3, 0, 2}, {1, 4, 5}} with axes names
36+
// {"data", "model"} represents a 2 * 3 mesh of 6 devices, with "data" axis of
37+
// size 2 and "model" axis of size 3.
38+
class Mesh {
39+
private:
40+
// Dimensions of the `device_assignment_` array correspond to the axes of the
41+
// mesh.
42+
TileAssignment device_assignment_;
43+
// Axes names correspond to names of axes represented by dimensions of
44+
// `device_assignment_`. Size of `axes_names_` should be equal to the number
45+
// of dimensions in the device_assignment_.
46+
std::vector<std::string> axes_names_;
47+
};
48+
49+
// C++ representation for corresponding `OpSharding::AxisRef`proto so same
50+
// documentation applies.
51+
class AxisRef {
52+
private:
53+
struct SubAxis {
54+
int64_t pre_size;
55+
int64_t size;
56+
};
57+
58+
// Index corresponding to axis in the mesh. It should be a valid index into
59+
// `mesh.axes_names_`.
60+
int64_t mesh_axis_index_;
61+
std::optional<SubAxis> sub_axis_info_;
62+
};
63+
64+
} // namespace xla
65+
66+
#endif // XLA_HLO_IR_MESH_AND_AXIS_H_

xla/hlo/ir/named_sharding.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_HLO_IR_NAMED_SHARDING_H_
17+
#define XLA_HLO_IR_NAMED_SHARDING_H_
18+
19+
#include <vector>
20+
21+
#include "xla/hlo/ir/mesh_and_axis.h"
22+
#include "xla/xla_data.pb.h"
23+
24+
namespace xla {
25+
26+
// C++ representation for corresponding `OpSharding::NamedSharding` proto so
27+
// same documentation applies.
28+
//
29+
// TODO(b/450770542): Add corresponding IFTTT in attrs.td
30+
class NamedSharding {
31+
struct DimensionSharding {
32+
std::vector<AxisRef> axes;
33+
bool is_closed;
34+
};
35+
36+
std::vector<NamedSharding> tuple_shardings_;
37+
38+
Mesh mesh_;
39+
std::vector<DimensionSharding> dim_shardings_;
40+
std::vector<AxisRef> replicated_axes_;
41+
std::vector<AxisRef> unreduced_axes_;
42+
std::vector<OpMetadata> metadata_;
43+
};
44+
45+
} // namespace xla
46+
47+
#endif // XLA_HLO_IR_NAMED_SHARDING_H_

0 commit comments

Comments
 (0)