Skip to content

Commit 5bb8b2f

Browse files
pschuh authored and Google-ML-Automation committed
Remove duplicated thunks execution logic.
PiperOrigin-RevId: 853008488
1 parent 9cec995 commit 5bb8b2f

File tree

1 file changed

+69
-140
lines changed

1 file changed

+69
-140
lines changed

xla/pjrt/cpu/cpu_client.cc

Lines changed: 69 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,7 +1507,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
15071507
run_options.set_device_assignment(device_assignment.get());
15081508
run_options.set_intra_op_thread_pool(client_->eigen_intraop_device());
15091509

1510-
auto cpu_run_options = std::make_shared<cpu::CpuExecutableRunOptions>();
1510+
auto cpu_run_options = std::make_unique<cpu::CpuExecutableRunOptions>();
15111511
run_options.set_cpu_executable_run_options(cpu_run_options.get());
15121512

15131513
const CpuExecuteContext* cpu_execute_context =
@@ -1567,83 +1567,87 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
15671567
execute_inline = true;
15681568
}
15691569

1570-
absl::Status inline_result_status;
1571-
if (input_deps.events().empty() && execute_inline) {
1572-
// Synchronously call generated function or thunk sequence.
1573-
1570+
auto execute_thunks = [cpu_executable, buffer_table = std::move(buffer_table),
1571+
eigen_device = client()->eigen_intraop_device(),
1572+
run_options = std::move(run_options)]()
1573+
-> absl::StatusOr<tsl::AsyncValueRef<cpu::Thunk::ExecuteEvent>> {
15741574
// Set denormal and rounding behavior to match the default TF
15751575
// ThreadPool behavior.
15761576
tsl::port::ScopedFlushDenormal flush;
15771577
tsl::port::ScopedSetRound round(FE_TONEAREST);
15781578

1579-
// Execution status for XLA:CPU thunk runtime.
1580-
tsl::AsyncValueRef<cpu::Thunk::ExecuteEvent> thunks_execute_event;
1581-
15821579
// Immediately allocate memory and prepare for computation.
1583-
buffer_alloc.Allocate(*client()->allocator());
1584-
buffer_alloc_and_copy.AllocateAndCopy(*client()->allocator());
15851580
for (const auto& buffer_info : buffer_table) {
15861581
CHECK(buffer_info.buffer.IsAvailable());
15871582
if (buffer_info.buffer.IsError()) {
15881583
return buffer_info.buffer.GetError();
15891584
}
15901585
}
15911586

1592-
if (cpu_executable->has_thunks()) {
1593-
// Call interpreted thunk sequence implementing XLA executable.
1594-
absl::InlinedVector<MaybeOwningDeviceAddress, 8> buffer_device_mem;
1595-
buffer_device_mem.reserve(buffer_table.size());
1596-
for (const auto& buffer_info : buffer_table) {
1597-
buffer_device_mem.emplace_back(
1598-
se::DeviceAddressBase(buffer_info.buffer->untyped_data(),
1599-
buffer_info.buffer->size_bytes()));
1600-
}
1601-
1602-
cpu::BufferAllocations allocations(buffer_device_mem);
1587+
if (!cpu_executable->has_thunks()) {
1588+
return Internal("CpuExecutable has no thunks.");
1589+
}
1590+
// Call interpreted thunk sequence implementing XLA executable.
1591+
absl::InlinedVector<MaybeOwningDeviceAddress, 8> buffer_device_mem;
1592+
buffer_device_mem.reserve(buffer_table.size());
1593+
for (const auto& buffer_info : buffer_table) {
1594+
buffer_device_mem.emplace_back(
1595+
se::DeviceAddressBase(buffer_info.buffer->untyped_data(),
1596+
buffer_info.buffer->size_bytes()));
1597+
}
16031598

1604-
TF_ASSIGN_OR_RETURN(
1605-
cpu::Thunk::CollectiveExecuteParams collective_params,
1606-
cpu::Thunk::CollectiveExecuteParams::Create(&run_options));
1599+
cpu::BufferAllocations allocations(buffer_device_mem);
16071600

1608-
TF_ASSIGN_OR_RETURN(
1609-
cpu::Thunk::CustomCallExecuteParams custom_call_execute_params,
1610-
cpu::Thunk::CustomCallExecuteParams::Create(&run_options));
1601+
TF_ASSIGN_OR_RETURN(
1602+
cpu::Thunk::CollectiveExecuteParams collective_params,
1603+
cpu::Thunk::CollectiveExecuteParams::Create(&run_options));
16111604

1612-
std::optional<cpu::Thunk::YnnParams> ynn_params;
1613-
if (cpu_executable->has_ynn_fusions()) {
1614-
TF_ASSIGN_OR_RETURN(ynn_params,
1615-
cpu::Thunk::YnnParams::Create(&run_options));
1616-
}
1605+
TF_ASSIGN_OR_RETURN(
1606+
cpu::Thunk::CustomCallExecuteParams custom_call_execute_params,
1607+
cpu::Thunk::CustomCallExecuteParams::Create(&run_options));
16171608

1618-
cpu::ThreadPoolTaskRunner task_runner(
1619-
run_options.intra_op_thread_pool()->getPool());
1620-
1621-
cpu::Thunk::ExecuteParams execute_params = {
1622-
cpu_executable->function_library(),
1623-
&allocations,
1624-
cpu::GetXfeedManager(run_options.device_ordinal()),
1625-
run_options.intra_op_thread_pool(),
1626-
&task_runner,
1627-
&collective_params,
1628-
&custom_call_execute_params,
1629-
ynn_params ? &*ynn_params : nullptr,
1630-
run_options.run_id().ToInt(),
1631-
run_options.device_ordinal(),
1632-
};
1609+
std::optional<cpu::Thunk::YnnParams> ynn_params;
1610+
if (cpu_executable->has_ynn_fusions()) {
1611+
TF_ASSIGN_OR_RETURN(ynn_params,
1612+
cpu::Thunk::YnnParams::Create(&run_options));
1613+
}
16331614

1634-
thunks_execute_event = cpu_executable->thunks().Execute(execute_params);
1615+
cpu::ThreadPoolTaskRunner task_runner(
1616+
run_options.intra_op_thread_pool()->getPool());
1617+
1618+
cpu::Thunk::ExecuteParams execute_params = {
1619+
cpu_executable->function_library(),
1620+
&allocations,
1621+
cpu::GetXfeedManager(run_options.device_ordinal()),
1622+
run_options.intra_op_thread_pool(),
1623+
&task_runner,
1624+
&collective_params,
1625+
&custom_call_execute_params,
1626+
ynn_params ? &*ynn_params : nullptr,
1627+
run_options.run_id().ToInt(),
1628+
run_options.device_ordinal(),
1629+
};
1630+
1631+
auto thunks_execute_event =
1632+
cpu_executable->thunks().Execute(execute_params);
1633+
1634+
tsl::profiler::TraceMe trace([&] {
1635+
return tsl::profiler::TraceMeEncode(
1636+
"ThunkExecutor::Execute (wait for completion)",
1637+
{{"run_id", run_options.run_id().ToInt()},
1638+
{"device_ordinal", run_options.device_ordinal()}});
1639+
});
1640+
tsl::BlockUntilReady(thunks_execute_event);
1641+
return thunks_execute_event;
1642+
};
16351643

1636-
tsl::profiler::TraceMe trace([&] {
1637-
return tsl::profiler::TraceMeEncode(
1638-
"ThunkExecutor::Execute (wait for completion)",
1639-
{{"run_id", run_options.run_id().ToInt()},
1640-
{"device_ordinal", run_options.device_ordinal()}});
1641-
});
1642-
tsl::BlockUntilReady(thunks_execute_event);
1644+
absl::Status inline_result_status;
1645+
if (input_deps.events().empty() && execute_inline) {
1646+
// Synchronously call generated function or thunk sequence.
1647+
buffer_alloc.Allocate(*client()->allocator());
1648+
buffer_alloc_and_copy.AllocateAndCopy(*client()->allocator());
16431649

1644-
} else {
1645-
return Internal("CpuExecutable has no thunks.");
1646-
}
1650+
TF_ASSIGN_OR_RETURN(auto thunks_execute_event, execute_thunks());
16471651

16481652
if (thunks_execute_event.IsError()) {
16491653
inline_result_status = thunks_execute_event.GetError();
@@ -1683,16 +1687,14 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
16831687
events_avs_ref,
16841688
[cpu_executable, buffer_alloc = std::move(buffer_alloc),
16851689
buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
1686-
buffer_table = std::move(buffer_table),
1687-
run_options = std::move(run_options),
1690+
execute_thunks = std::move(execute_thunks),
16881691
device_assignment = std::move(device_assignment),
16891692
cpu_run_options = std::move(cpu_run_options),
16901693
compute_reservation = std::move(compute_reservation),
16911694
tuple_index_table = std::move(tuple_index_table),
16921695
scoped_async_execution = std::move(scoped_async_execution),
16931696
input_deps_avs = std::move(input_deps).Consume(),
16941697
allocator = client()->allocator(),
1695-
eigen_device = client()->eigen_intraop_device(),
16961698
returned_future_can_be_set_event =
16971699
returned_future_can_be_set_event.CopyRef()]() mutable {
16981700
// Because `input_deps` contains the definition events of all inputs,
@@ -1710,86 +1712,13 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
17101712
return;
17111713
}
17121714
}
1713-
1714-
// Set denormal and rounding behavior to match the default TF
1715-
// ThreadPool behavior.
1716-
tsl::port::ScopedFlushDenormal flush;
1717-
tsl::port::ScopedSetRound round(FE_TONEAREST);
1718-
1719-
for (const auto& buffer_info : buffer_table) {
1720-
CHECK(buffer_info.buffer.IsAvailable());
1721-
if (buffer_info.buffer.IsError()) {
1722-
scoped_async_execution.SetError(
1723-
Internal("Error preparing computation: %s",
1724-
buffer_info.buffer.GetError().message()));
1725-
returned_future_can_be_set_event.SetStateConcrete();
1726-
return;
1727-
}
1728-
}
1729-
absl::Status status;
1730-
if (cpu_executable->has_thunks()) {
1731-
// Call interpreted thunk sequence implementing XLA executable.
1732-
absl::InlinedVector<MaybeOwningDeviceAddress, 8> buffer_device_mem;
1733-
buffer_device_mem.reserve(buffer_table.size());
1734-
for (const auto& buffer_info : buffer_table) {
1735-
buffer_device_mem.emplace_back(
1736-
se::DeviceAddressBase(buffer_info.buffer->untyped_data(),
1737-
buffer_info.buffer->size_bytes()));
1738-
}
1739-
1740-
cpu::BufferAllocations allocations(buffer_device_mem);
1741-
1742-
absl::StatusOr<cpu::Thunk::CollectiveExecuteParams>
1743-
collective_params =
1744-
cpu::Thunk::CollectiveExecuteParams::Create(&run_options);
1745-
1746-
absl::StatusOr<cpu::Thunk::CustomCallExecuteParams>
1747-
custom_call_params =
1748-
cpu::Thunk::CustomCallExecuteParams::Create(&run_options);
1749-
1750-
absl::StatusOr<std::optional<cpu::Thunk::YnnParams>> ynn_params(
1751-
std::nullopt);
1752-
if (cpu_executable->has_ynn_fusions()) {
1753-
ynn_params = cpu::Thunk::YnnParams::Create(&run_options);
1715+
auto status = [&]() -> absl::Status {
1716+
TF_ASSIGN_OR_RETURN(auto thunks_execute_event, execute_thunks());
1717+
if (thunks_execute_event.IsError()) {
1718+
return thunks_execute_event.GetError();
17541719
}
1755-
1756-
cpu::ThreadPoolTaskRunner task_runner(
1757-
run_options.intra_op_thread_pool()->getPool());
1758-
1759-
if (collective_params.ok()) {
1760-
cpu::Thunk::ExecuteParams execute_params = {
1761-
cpu_executable->function_library(),
1762-
&allocations,
1763-
cpu::GetXfeedManager(run_options.device_ordinal()),
1764-
run_options.intra_op_thread_pool(),
1765-
&task_runner,
1766-
&*collective_params,
1767-
&*custom_call_params,
1768-
*ynn_params ? &**ynn_params : nullptr,
1769-
run_options.run_id().ToInt(),
1770-
run_options.device_ordinal(),
1771-
};
1772-
1773-
auto thunks_execute_event =
1774-
cpu_executable->thunks().Execute(execute_params);
1775-
1776-
tsl::profiler::TraceMe trace([&] {
1777-
return tsl::profiler::TraceMeEncode(
1778-
"ThunkExecutor::Execute (wait for completion)",
1779-
{{"run_id", run_options.run_id().ToInt()},
1780-
{"device_ordinal", run_options.device_ordinal()}});
1781-
});
1782-
tsl::BlockUntilReady(thunks_execute_event);
1783-
status = thunks_execute_event.IsError()
1784-
? thunks_execute_event.GetError()
1785-
: absl::OkStatus();
1786-
} else {
1787-
status = collective_params.status();
1788-
}
1789-
1790-
} else {
1791-
status = Internal("CpuExecutable has no thunks.");
1792-
}
1720+
return absl::OkStatus();
1721+
}();
17931722

17941723
if (!status.ok()) {
17951724
// CPU computation fails with an error.

0 commit comments

Comments
 (0)