Skip to content

Commit 20a4825

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla] Add LiteralPool and LiteralCanonicalizer to share constant literals between HLO modules
This change saves a lot of host memory from duplicate constant literals in instantiated HLO modules. PiperOrigin-RevId: 707153215
1 parent 723f02b commit 20a4825

File tree

13 files changed

+498
-0
lines changed

13 files changed

+498
-0
lines changed

xla/BUILD

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,33 @@ xla_cc_test(
629629
],
630630
)
631631

632+
cc_library(
633+
name = "literal_pool",
634+
srcs = ["literal_pool.cc"],
635+
hdrs = ["literal_pool.h"],
636+
visibility = ["//visibility:public"],
637+
deps = [
638+
":literal",
639+
":shape_util",
640+
"@com_google_absl//absl/base:core_headers",
641+
"@com_google_absl//absl/container:flat_hash_map",
642+
"@com_google_absl//absl/synchronization",
643+
"@tsl//tsl/platform:logging",
644+
],
645+
)
646+
647+
xla_cc_test(
648+
name = "literal_pool_test",
649+
srcs = ["literal_pool_test.cc"],
650+
deps = [
651+
":literal",
652+
":literal_pool",
653+
":literal_util",
654+
"@tsl//tsl/platform:test",
655+
"@tsl//tsl/platform:test_main",
656+
],
657+
)
658+
632659
cc_library(
633660
name = "literal_util",
634661
srcs = ["literal_util.cc"],

xla/hlo/ir/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cc_library(
6565
"//xla:array",
6666
"//xla:comparison_util",
6767
"//xla:literal",
68+
"//xla:literal_pool",
6869
"//xla:literal_util",
6970
"//xla:printer",
7071
"//xla:protobuf_util",

xla/hlo/ir/dfs_hlo_visitor_with_default.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
380380

381381
// Mark the computation as having changed.
382382
void MarkAsChanged() { changed_ = true; }
383+
void MarkAsMaybeChanged(bool changed) { changed_ |= changed; }
383384

384385
private:
385386
bool changed_ = false;

xla/hlo/ir/hlo_instructions.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ limitations under the License.
4040
#include "xla/hlo/ir/hlo_opcode.h"
4141
#include "xla/layout.h"
4242
#include "xla/literal.h"
43+
#include "xla/literal_pool.h"
4344
#include "xla/printer.h"
4445
#include "xla/service/hlo.pb.h"
4546
#include "xla/shape.h"
@@ -1343,6 +1344,18 @@ class HloConstantInstruction : public HloInstruction {
13431344
return hlo->opcode() == HloOpcode::kConstant;
13441345
}
13451346

1347+
// Canonicalize constant literal using the given literal pool.
1348+
bool Canonicalize(LiteralPool* literal_pool) {
1349+
if (literal_pool && literal_) {
1350+
auto canonical = literal_pool->GetCanonicalLiteral(literal_);
1351+
if (canonical != literal_) {
1352+
literal_ = std::move(canonical);
1353+
return true;
1354+
}
1355+
}
1356+
return false;
1357+
}
1358+
13461359
private:
13471360
bool IsElementwiseImpl(
13481361
const std::optional<int64_t>& operand_idx) const override;

xla/hlo/transforms/BUILD

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,6 +2026,40 @@ cc_library(
20262026
],
20272027
)
20282028

2029+
cc_library(
2030+
name = "literal_canonicalizer",
2031+
srcs = ["literal_canonicalizer.cc"],
2032+
hdrs = ["literal_canonicalizer.h"],
2033+
deps = [
2034+
"//xla:literal_pool",
2035+
"//xla/hlo/ir:hlo",
2036+
"//xla/hlo/pass:hlo_pass",
2037+
"//xla/hlo/pass:hlo_pass_pipeline",
2038+
"@com_google_absl//absl/container:flat_hash_set",
2039+
"@com_google_absl//absl/status",
2040+
"@com_google_absl//absl/status:statusor",
2041+
"@com_google_absl//absl/strings:string_view",
2042+
"@tsl//tsl/platform:errors",
2043+
"@tsl//tsl/platform:logging",
2044+
],
2045+
)
2046+
2047+
xla_cc_test(
2048+
name = "literal_canonicalizer_test",
2049+
srcs = ["literal_canonicalizer_test.cc"],
2050+
deps = [
2051+
":literal_canonicalizer",
2052+
"//xla:literal_pool",
2053+
"//xla/hlo/ir:hlo",
2054+
"//xla/hlo/parser:hlo_parser",
2055+
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
2056+
"@com_google_absl//absl/strings:string_view",
2057+
"@tsl//tsl/platform:statusor",
2058+
"@tsl//tsl/platform:test",
2059+
"@tsl//tsl/platform:test_main",
2060+
],
2061+
)
2062+
20292063
cc_library(
20302064
name = "optimize_input_output_buffer_alias",
20312065
srcs = ["simplifiers/optimize_input_output_buffer_alias.cc"],
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/transforms/literal_canonicalizer.h"
17+
18+
#include <cstddef>
19+
20+
#include "absl/container/flat_hash_set.h"
21+
#include "absl/status/status.h"
22+
#include "absl/status/statusor.h"
23+
#include "absl/strings/string_view.h"
24+
#include "xla/hlo/ir/dfs_hlo_visitor.h"
25+
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
26+
#include "xla/hlo/ir/hlo_casting_utils.h"
27+
#include "xla/hlo/ir/hlo_instruction.h"
28+
#include "xla/hlo/ir/hlo_instructions.h"
29+
#include "xla/hlo/ir/hlo_module.h"
30+
#include "xla/literal_pool.h"
31+
#include "tsl/platform/errors.h"
32+
#include "tsl/platform/logging.h"
33+
34+
namespace xla {
35+
namespace {
36+
37+
class LiteralCanonicalizerVisitor : public DfsHloRewriteVisitor {
38+
public:
39+
LiteralCanonicalizerVisitor(LiteralPool* literal_pool, size_t min_size_bytes)
40+
: literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {}
41+
42+
absl::Status HandleConstant(HloInstruction* hlo) final {
43+
auto* constant = Cast<HloConstantInstruction>(hlo);
44+
if (constant->HasLiteral() &&
45+
constant->literal().size_bytes() >= min_size_bytes_) {
46+
MarkAsMaybeChanged(constant->Canonicalize(literal_pool_));
47+
}
48+
return absl::OkStatus();
49+
}
50+
51+
private:
52+
LiteralPool* literal_pool_;
53+
size_t min_size_bytes_;
54+
};
55+
56+
} // namespace
57+
58+
LiteralCanonicalizer::LiteralCanonicalizer(LiteralPool* literal_pool,
59+
size_t min_size_bytes)
60+
: literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {}
61+
62+
absl::StatusOr<bool> LiteralCanonicalizer::Run(
63+
HloModule* module,
64+
const absl::flat_hash_set<absl::string_view>& execution_threads) {
65+
// Every time we canonicalize literals in a module, we garbage collect expired
66+
// literals from the pool.
67+
size_t num_erased = literal_pool_->GarbageCollect();
68+
VLOG(3) << "Garbage collected " << num_erased << " expired literals";
69+
70+
LiteralCanonicalizerVisitor visitor(literal_pool_, min_size_bytes_);
71+
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&visitor));
72+
return visitor.changed();
73+
}
74+
75+
} // namespace xla
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_
17+
#define XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_
18+
19+
#include <cstddef>
20+
21+
#include "absl/container/flat_hash_set.h"
22+
#include "absl/status/statusor.h"
23+
#include "absl/strings/string_view.h"
24+
#include "xla/hlo/ir/hlo_module.h"
25+
#include "xla/hlo/pass/hlo_pass_interface.h"
26+
#include "xla/literal_pool.h"
27+
28+
namespace xla {
29+
30+
// Canonicalizes literals larger than 'min_size_bytes' in the HLO module using
31+
// the given literal pool.
32+
class LiteralCanonicalizer : public HloModulePass {
33+
public:
34+
LiteralCanonicalizer(LiteralPool* literal_pool, size_t min_size_bytes);
35+
36+
using HloPassInterface::Run;
37+
absl::StatusOr<bool> Run(
38+
HloModule* module,
39+
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
40+
41+
absl::string_view name() const override { return "literal-canonicalizer"; }
42+
43+
protected:
44+
LiteralPool* literal_pool_;
45+
size_t min_size_bytes_;
46+
};
47+
48+
} // namespace xla
49+
50+
#endif // XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/transforms/literal_canonicalizer.h"
17+
18+
#include "absl/strings/string_view.h"
19+
#include "xla/hlo/ir/hlo_casting_utils.h"
20+
#include "xla/hlo/ir/hlo_instructions.h"
21+
#include "xla/hlo/parser/hlo_parser.h"
22+
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
23+
#include "xla/literal_pool.h"
24+
#include "tsl/platform/statusor.h"
25+
#include "tsl/platform/test.h"
26+
27+
namespace xla {
28+
namespace {
29+
30+
class LiteralCanonicalizerTest : public HloHardwareIndependentTestBase {};
31+
32+
TEST_F(LiteralCanonicalizerTest, CanonicalizeConstants) {
33+
absl::string_view hlo_string = R"(
34+
HloModule m
35+
36+
ENTRY %entry {
37+
ROOT %c0 = f32[4] constant({1.0, 2.0, 3.0, 4.0})
38+
}
39+
)";
40+
41+
TF_ASSERT_OK_AND_ASSIGN(auto module0,
42+
ParseAndReturnVerifiedModule(hlo_string));
43+
TF_ASSERT_OK_AND_ASSIGN(auto module1,
44+
ParseAndReturnVerifiedModule(hlo_string));
45+
46+
LiteralPool literal_pool;
47+
LiteralCanonicalizer literal_canonicalizer(&literal_pool, 0);
48+
49+
EXPECT_FALSE(literal_canonicalizer.Run(module0.get()).value());
50+
EXPECT_TRUE(literal_canonicalizer.Run(module1.get()).value());
51+
52+
auto* c0 = Cast<HloConstantInstruction>(
53+
module0->entry_computation()->root_instruction());
54+
auto* c1 = Cast<HloConstantInstruction>(
55+
module1->entry_computation()->root_instruction());
56+
57+
EXPECT_EQ(c0->literal(), c1->literal());
58+
}
59+
60+
} // namespace
61+
} // namespace xla

0 commit comments

Comments
 (0)