Skip to content

Commit 2c70c13

Browse files
[xla:cpu][roll forward] Improve compilation time by not fusing large constants into LLVM modules
Fix for the breaking of a large model without thunks. Add tests to make sure this doesn't happen again. Reverts 1971267 PiperOrigin-RevId: 707524044
1 parent cea5792 commit 2c70c13

File tree

9 files changed

+298
-11
lines changed

9 files changed

+298
-11
lines changed

xla/service/cpu/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ xla_test(
456456
],
457457
tags = [
458458
"test_migrated_to_hlo_runner_pjrt",
459+
"test_xla_cpu_no_thunks",
459460
],
460461
deps = [
461462
"//xla/hlo/testlib:verified_hlo_module",
@@ -686,18 +687,39 @@ xla_cc_test(
686687
name = "ir_emitter_test",
687688
srcs = ["ir_emitter_test.cc"],
688689
deps = [
690+
":cpu_compiler",
691+
":cpu_executable",
692+
":cpu_options",
689693
":ir_emitter",
690694
":ir_function",
695+
":runtime_symbol_generator",
691696
":target_machine_features_stub",
697+
"//xla:cpu_function_runtime",
698+
"//xla/backends/cpu/codegen:cpu_features",
699+
"//xla/backends/cpu/codegen:ir_compiler",
700+
"//xla/backends/cpu/codegen:jit_compiler",
701+
"//xla/backends/cpu/codegen:target_machine_features",
692702
"//xla/hlo/analysis:hlo_ordering",
693703
"//xla/hlo/ir:hlo",
694704
"//xla/hlo/parser:hlo_parser",
705+
"//xla/hlo/transforms:hlo_memory_scheduler",
695706
"//xla/service:buffer_assignment",
707+
"//xla/service:buffer_value",
696708
"//xla/service:hlo_module_config",
697709
"//xla/service:logical_buffer",
710+
"//xla/service/llvm_ir:llvm_util",
698711
"//xla/tests:hlo_test_base",
712+
"//xla/tsl/lib/core:status_test_util",
713+
"@com_google_absl//absl/container:flat_hash_map",
714+
"@com_google_absl//absl/status:statusor",
715+
"@com_google_absl//absl/strings:string_view",
716+
"@com_google_googletest//:gtest",
699717
"@llvm-project//llvm:Core",
700718
"@llvm-project//llvm:Support",
719+
"@llvm-project//llvm:Target",
720+
"@llvm-project//mlir:IR",
721+
"@tsl//tsl/platform:env",
722+
"@tsl//tsl/platform:errors",
701723
"@tsl//tsl/platform:statusor",
702724
"@tsl//tsl/platform:test",
703725
"@tsl//tsl/platform:test_main",
@@ -742,6 +764,7 @@ cc_library(
742764
copts = tsl_copts(),
743765
deps = [
744766
":backend_config_proto_cc",
767+
":cpu_instruction_fusion",
745768
":cpu_options",
746769
":cpu_runtime",
747770
":dot_op_emitter",

xla/service/cpu/cpu_compiler.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,17 +1498,15 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
14981498
#endif
14991499
);
15001500

1501-
// Emit global variables for constants.
1502-
//
1503-
// TODO(ezhulenev): Figure out how to emit constants that are only needed for
1504-
// thread local computations as with Thunks runtime we keep constants outside
1505-
// of the LLVM module. Currently we end up doubling memory for constants.
1506-
TF_RETURN_IF_ERROR(nested_ir_emitter.EmitConstantGlobals());
15071501

15081502
// If we use Thunk runtime then instead of emitting LLVM function for the
15091503
// entry computation we emit a sequence of thunks that implement the
15101504
// computation as a sequence of interpreted commands.
15111505
if (module->config().debug_options().xla_cpu_use_thunk_runtime()) {
1506+
// The thunk runtime manages large constants, therefore we only emit
1507+
// small ones.
1508+
TF_RETURN_IF_ERROR(nested_ir_emitter.EmitSmallConstantGlobals());
1509+
15121510
// IR emitter is responsible for building LLVM module with host kernels for
15131511
// corresponding HLO instructions (fusions, elemental instructions, etc.).
15141512
IrEmitter2 ir_emitter2(*module, llvm_module.get(), &nested_ir_emitter);
@@ -1642,6 +1640,8 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
16421640
return with_hlo_proto(std::move(cpu_executable));
16431641
}
16441642

1643+
TF_RETURN_IF_ERROR(nested_ir_emitter.EmitAllConstantGlobals());
1644+
16451645
// Each computation is a single function. Emit all embedded computations
16461646
// before the entry computation. The order of computations returned from
16471647
// SubcomputationEmissionOrder guarantees that a called computation occurs
@@ -1899,7 +1899,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
18991899
// TODO(b/66051036): Run full msan for AOT.
19001900
/*emit_code_for_msan=*/false);
19011901

1902-
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
1902+
TF_RETURN_IF_ERROR(ir_emitter.EmitAllConstantGlobals());
19031903

19041904
for (ComputationToEmit subcomputation :
19051905
SubcomputationEmissionOrder(computation)) {

xla/service/cpu/cpu_compiler_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
See the License for the specific language governing permissions and
1111
limitations under the License.
1212
==============================================================================*/
13+
1314
#include <memory>
1415
#include <string>
1516
#include <utility>
@@ -55,6 +56,25 @@ TEST_F(CpuCompilerTest, RecordsStreamzStackTrace) {
5556
EXPECT_GT(it->second->points.size(), 0);
5657
}
5758

59+
TEST_F(CpuCompilerTest, CompilationWithLargeConstants) {
60+
absl::string_view module_string = R"(
61+
HloModule module
62+
63+
ENTRY main {
64+
a = f32[1000,1000]{1,0} parameter(0)
65+
b = f32[1000,1000]{1,0} constant({...})
66+
a_plus_b = f32[1000,1000]{1,0} add(a, b)
67+
c = f32[1000,1000]{1,0} constant({...})
68+
ROOT result = f32[1000,1000]{1,0} add(a_plus_b, c)
69+
}
70+
)";
71+
72+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
73+
ParseAndReturnVerifiedModule(module_string));
74+
75+
EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/true));
76+
}
77+
5878
} // namespace
5979
} // namespace cpu
6080
} // namespace xla

xla/service/cpu/cpu_instruction_fusion.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ limitations under the License.
1919

2020
#include "absl/algorithm/container.h"
2121
#include "absl/log/log.h"
22+
#include "xla/hlo/ir/hlo_casting_utils.h"
23+
#include "xla/hlo/ir/hlo_instruction.h"
24+
#include "xla/hlo/ir/hlo_instructions.h"
2225
#include "xla/hlo/ir/hlo_opcode.h"
2326
#include "xla/service/fusion_node_indexing_evaluation.h"
2427
#include "xla/service/instruction_fusion.h"
@@ -81,6 +84,10 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
8184

8285
constexpr int kFusionThresholdBytes = 16 * 1024;
8386

87+
if (IsLargeConstant(producer)) {
88+
return FusionDecision::Forbid("Don't fuse large constants.");
89+
}
90+
8491
if (CanBeOutputFused(producer, consumer)) {
8592
VLOG(2) << "Fusion OK: Can create output fusion.";
8693
return FusionDecision::Allow();
@@ -219,5 +226,12 @@ HloInstruction* CpuInstructionFusion::FuseInstruction(
219226
evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
220227
return new_producer;
221228
}
229+
230+
bool CpuInstructionFusion::IsLargeConstant(
231+
const HloInstruction* constant) const {
232+
return constant->IsConstant() &&
233+
Cast<HloConstantInstruction>(constant)->literal().size_bytes() >
234+
GetLargeConstantThresholdBytes();
235+
}
222236
} // namespace cpu
223237
} // namespace xla

xla/service/cpu/cpu_instruction_fusion.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ class CpuInstructionFusion : public InstructionFusion {
4343
return InstructionFusion::Run(module, execution_threads);
4444
}
4545

46+
// Returns the threshold for a constant to be considered a large constant.
47+
static constexpr int64_t GetLargeConstantThresholdBytes() {
48+
constexpr int64_t kLargeConstantThresholdBytes = 10000;
49+
return kLargeConstantThresholdBytes;
50+
}
51+
4652
protected:
4753
FusionDecision ShouldFuse(HloInstruction* consumer,
4854
int64_t operand_index) override;
@@ -53,6 +59,9 @@ class CpuInstructionFusion : public InstructionFusion {
5359
HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
5460
HloInstruction* producer) override;
5561

62+
// Returns if a constant is large enough to be considered a large constant.
63+
bool IsLargeConstant(const HloInstruction* constant) const;
64+
5665
// Keep track of the number of times each instruction inside a fusion node is
5766
// indexed with different index vectors.
5867
absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>

xla/service/cpu/cpu_instruction_fusion_test.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,5 +935,45 @@ ENTRY main {
935935
EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion());
936936
}
937937

938+
TEST_F(OpcodeFusionTest, BigConstantNotInFusion) {
939+
absl::string_view module_string = R"(
940+
HloModule module
941+
942+
ENTRY main {
943+
a = f32[1000,1000]{1,0} parameter(0)
944+
b = f32[1000,1000]{1,0} constant({...})
945+
a_plus_b = f32[1000,1000]{1,0} add(a, b)
946+
c = f32[1000,1000]{1,0} constant({...})
947+
ROOT result = f32[1000,1000]{1,0} add(a_plus_b, c)
948+
}
949+
)";
950+
951+
TF_ASSERT_OK_AND_ASSIGN(auto module,
952+
ParseAndReturnVerifiedModule(module_string));
953+
RunFusionAndCheckOpcodesWereFused(
954+
module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
955+
HloOpcode::kParameter, HloOpcode::kAdd, HloOpcode::kAdd});
956+
}
957+
958+
TEST_F(OpcodeFusionTest, SmallConstantInFusion) {
959+
absl::string_view module_string = R"(
960+
HloModule module
961+
962+
ENTRY main {
963+
a = f32[10,10]{1,0} parameter(0)
964+
b = f32[10,10]{1,0} constant({...})
965+
a_plus_b = f32[10,10]{1,0} add(a, b)
966+
c = f32[10,10]{1,0} constant({...})
967+
ROOT result = f32[10,10]{1,0} add(a_plus_b, c)
968+
}
969+
)";
970+
971+
TF_ASSERT_OK_AND_ASSIGN(auto module,
972+
ParseAndReturnVerifiedModule(module_string));
973+
RunFusionAndCheckOpcodesWereFused(
974+
module.get(), {HloOpcode::kParameter, HloOpcode::kConstant,
975+
HloOpcode::kConstant, HloOpcode::kAdd, HloOpcode::kAdd});
976+
}
977+
938978
} // namespace
939979
} // namespace xla::cpu

xla/service/cpu/ir_emitter.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ limitations under the License.
6767
#include "xla/service/buffer_assignment.h"
6868
#include "xla/service/collective_ops_utils.h"
6969
#include "xla/service/cpu/backend_config.pb.h"
70+
#include "xla/service/cpu/cpu_instruction_fusion.h"
7071
#include "xla/service/cpu/cpu_options.h"
7172
#include "xla/service/cpu/cpu_runtime.h"
7273
#include "xla/service/cpu/dot_op_emitter.h"
@@ -330,9 +331,24 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
330331
return result_global;
331332
}
332333

333-
absl::Status IrEmitter::EmitConstantGlobals() {
334+
absl::Status IrEmitter::EmitSmallConstantGlobals() {
335+
return EmitConstantGlobals(/*max_size_bytes=*/CpuInstructionFusion::
336+
GetLargeConstantThresholdBytes());
337+
}
338+
339+
absl::Status IrEmitter::EmitAllConstantGlobals() {
340+
return EmitConstantGlobals(/*max_size_bytes=*/std::nullopt);
341+
}
342+
343+
absl::Status IrEmitter::EmitConstantGlobals(
344+
std::optional<size_t> max_size_bytes) {
334345
for (const BufferAllocation& allocation : assignment_.Allocations()) {
335-
if (!allocation.is_constant()) {
346+
// Large constants don't get fused with other instructions, so we don't
347+
// need to emit them as globals.
348+
if (!allocation.is_constant() ||
349+
(max_size_bytes &&
350+
llvm_ir::LiteralForConstantAllocation(allocation).size_bytes() >
351+
*max_size_bytes)) {
336352
continue;
337353
}
338354

xla/service/cpu/ir_emitter.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
177177
compute_function_.pop();
178178
}
179179

180-
// Emit an LLVM global variable for every constant buffer allocation.
181-
absl::Status EmitConstantGlobals();
180+
// Emit LLVM global variable for a small constant buffer allocation.
181+
absl::Status EmitSmallConstantGlobals();
182+
183+
// Emit LLVM global variables for all constant buffer allocations.
184+
absl::Status EmitAllConstantGlobals();
182185

183186
// Emits a call to a thread local function (e.g. to the computation nested
184187
// within a reduce or a map). Thread local callees (by definition) only write
@@ -239,6 +242,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
239242
protected:
240243
friend class IrEmitter2;
241244

245+
// Emit an LLVM global variable for every constant buffer allocation.
246+
absl::Status EmitConstantGlobals(std::optional<size_t> max_size_bytes);
247+
242248
//
243249
// The following methods implement the DfsHloVisitor interface.
244250
//

0 commit comments

Comments
 (0)