@@ -1507,7 +1507,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
15071507 run_options.set_device_assignment (device_assignment.get ());
15081508 run_options.set_intra_op_thread_pool (client_->eigen_intraop_device ());
15091509
1510- auto cpu_run_options = std::make_shared <cpu::CpuExecutableRunOptions>();
1510+ auto cpu_run_options = std::make_unique <cpu::CpuExecutableRunOptions>();
15111511 run_options.set_cpu_executable_run_options (cpu_run_options.get ());
15121512
15131513 const CpuExecuteContext* cpu_execute_context =
@@ -1567,83 +1567,87 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
15671567 execute_inline = true ;
15681568 }
15691569
1570- absl::Status inline_result_status;
1571- if (input_deps. events (). empty () && execute_inline) {
1572- // Synchronously call generated function or thunk sequence.
1573-
1570+ auto execute_thunks = [cpu_executable, buffer_table = std::move (buffer_table),
1571+ eigen_device = client ()-> eigen_intraop_device (),
1572+ run_options = std::move (run_options)]()
1573+ -> absl::StatusOr<tsl::AsyncValueRef<cpu::Thunk::ExecuteEvent>> {
15741574 // Set denormal and rounding behavior to match the default TF
15751575 // ThreadPool behavior.
15761576 tsl::port::ScopedFlushDenormal flush;
15771577 tsl::port::ScopedSetRound round (FE_TONEAREST);
15781578
1579- // Execution status for XLA:CPU thunk runtime.
1580- tsl::AsyncValueRef<cpu::Thunk::ExecuteEvent> thunks_execute_event;
1581-
15821579 // Immediately allocate memory and prepare for computation.
1583- buffer_alloc.Allocate (*client ()->allocator ());
1584- buffer_alloc_and_copy.AllocateAndCopy (*client ()->allocator ());
15851580 for (const auto & buffer_info : buffer_table) {
15861581 CHECK (buffer_info.buffer .IsAvailable ());
15871582 if (buffer_info.buffer .IsError ()) {
15881583 return buffer_info.buffer .GetError ();
15891584 }
15901585 }
15911586
1592- if (cpu_executable->has_thunks ()) {
1593- // Call interpreted thunk sequence implementing XLA executable.
1594- absl::InlinedVector<MaybeOwningDeviceAddress, 8 > buffer_device_mem;
1595- buffer_device_mem. reserve (buffer_table. size ());
1596- for ( const auto & buffer_info : buffer_table) {
1597- buffer_device_mem.emplace_back (
1598- se::DeviceAddressBase (buffer_info. buffer -> untyped_data (),
1599- buffer_info. buffer -> size_bytes ()));
1600- }
1601-
1602- cpu::BufferAllocations allocations (buffer_device_mem);
1587+ if (! cpu_executable->has_thunks ()) {
1588+ return Internal ( " CpuExecutable has no thunks. " );
1589+ }
1590+ // Call interpreted thunk sequence implementing XLA executable.
1591+ absl::InlinedVector<MaybeOwningDeviceAddress, 8 > buffer_device_mem;
1592+ buffer_device_mem.reserve (buffer_table. size ());
1593+ for ( const auto & buffer_info : buffer_table) {
1594+ buffer_device_mem. emplace_back (
1595+ se::DeviceAddressBase (buffer_info. buffer -> untyped_data (),
1596+ buffer_info. buffer -> size_bytes ()));
1597+ }
16031598
1604- TF_ASSIGN_OR_RETURN (
1605- cpu::Thunk::CollectiveExecuteParams collective_params,
1606- cpu::Thunk::CollectiveExecuteParams::Create (&run_options));
1599+ cpu::BufferAllocations allocations (buffer_device_mem);
16071600
1608- TF_ASSIGN_OR_RETURN (
1609- cpu::Thunk::CustomCallExecuteParams custom_call_execute_params ,
1610- cpu::Thunk::CustomCallExecuteParams ::Create (&run_options));
1601+ TF_ASSIGN_OR_RETURN (
1602+ cpu::Thunk::CollectiveExecuteParams collective_params ,
1603+ cpu::Thunk::CollectiveExecuteParams ::Create (&run_options));
16111604
1612- std::optional<cpu::Thunk::YnnParams> ynn_params;
1613- if (cpu_executable->has_ynn_fusions ()) {
1614- TF_ASSIGN_OR_RETURN (ynn_params,
1615- cpu::Thunk::YnnParams::Create (&run_options));
1616- }
1605+ TF_ASSIGN_OR_RETURN (
1606+ cpu::Thunk::CustomCallExecuteParams custom_call_execute_params,
1607+ cpu::Thunk::CustomCallExecuteParams::Create (&run_options));
16171608
1618- cpu::ThreadPoolTaskRunner task_runner (
1619- run_options.intra_op_thread_pool ()->getPool ());
1620-
1621- cpu::Thunk::ExecuteParams execute_params = {
1622- cpu_executable->function_library (),
1623- &allocations,
1624- cpu::GetXfeedManager (run_options.device_ordinal ()),
1625- run_options.intra_op_thread_pool (),
1626- &task_runner,
1627- &collective_params,
1628- &custom_call_execute_params,
1629- ynn_params ? &*ynn_params : nullptr ,
1630- run_options.run_id ().ToInt (),
1631- run_options.device_ordinal (),
1632- };
1609+ std::optional<cpu::Thunk::YnnParams> ynn_params;
1610+ if (cpu_executable->has_ynn_fusions ()) {
1611+ TF_ASSIGN_OR_RETURN (ynn_params,
1612+ cpu::Thunk::YnnParams::Create (&run_options));
1613+ }
16331614
1634- thunks_execute_event = cpu_executable->thunks ().Execute (execute_params);
1615+ cpu::ThreadPoolTaskRunner task_runner (
1616+ run_options.intra_op_thread_pool ()->getPool ());
1617+
1618+ cpu::Thunk::ExecuteParams execute_params = {
1619+ cpu_executable->function_library (),
1620+ &allocations,
1621+ cpu::GetXfeedManager (run_options.device_ordinal ()),
1622+ run_options.intra_op_thread_pool (),
1623+ &task_runner,
1624+ &collective_params,
1625+ &custom_call_execute_params,
1626+ ynn_params ? &*ynn_params : nullptr ,
1627+ run_options.run_id ().ToInt (),
1628+ run_options.device_ordinal (),
1629+ };
1630+
1631+ auto thunks_execute_event =
1632+ cpu_executable->thunks ().Execute (execute_params);
1633+
1634+ tsl::profiler::TraceMe trace ([&] {
1635+ return tsl::profiler::TraceMeEncode (
1636+ " ThunkExecutor::Execute (wait for completion)" ,
1637+ {{" run_id" , run_options.run_id ().ToInt ()},
1638+ {" device_ordinal" , run_options.device_ordinal ()}});
1639+ });
1640+ tsl::BlockUntilReady (thunks_execute_event);
1641+ return thunks_execute_event;
1642+ };
16351643
1636- tsl::profiler::TraceMe trace ([&] {
1637- return tsl::profiler::TraceMeEncode (
1638- " ThunkExecutor::Execute (wait for completion)" ,
1639- {{" run_id" , run_options.run_id ().ToInt ()},
1640- {" device_ordinal" , run_options.device_ordinal ()}});
1641- });
1642- tsl::BlockUntilReady (thunks_execute_event);
1644+ absl::Status inline_result_status;
1645+ if (input_deps.events ().empty () && execute_inline) {
1646+ // Synchronously call generated function or thunk sequence.
1647+ buffer_alloc.Allocate (*client ()->allocator ());
1648+ buffer_alloc_and_copy.AllocateAndCopy (*client ()->allocator ());
16431649
1644- } else {
1645- return Internal (" CpuExecutable has no thunks." );
1646- }
1650+ TF_ASSIGN_OR_RETURN (auto thunks_execute_event, execute_thunks ());
16471651
16481652 if (thunks_execute_event.IsError ()) {
16491653 inline_result_status = thunks_execute_event.GetError ();
@@ -1683,16 +1687,14 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
16831687 events_avs_ref,
16841688 [cpu_executable, buffer_alloc = std::move (buffer_alloc),
16851689 buffer_alloc_and_copy = std::move (buffer_alloc_and_copy),
1686- buffer_table = std::move (buffer_table),
1687- run_options = std::move (run_options),
1690+ execute_thunks = std::move (execute_thunks),
16881691 device_assignment = std::move (device_assignment),
16891692 cpu_run_options = std::move (cpu_run_options),
16901693 compute_reservation = std::move (compute_reservation),
16911694 tuple_index_table = std::move (tuple_index_table),
16921695 scoped_async_execution = std::move (scoped_async_execution),
16931696 input_deps_avs = std::move (input_deps).Consume (),
16941697 allocator = client ()->allocator (),
1695- eigen_device = client ()->eigen_intraop_device (),
16961698 returned_future_can_be_set_event =
16971699 returned_future_can_be_set_event.CopyRef ()]() mutable {
16981700 // Because `input_deps` contains the definition events of all inputs,
@@ -1710,86 +1712,13 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
17101712 return ;
17111713 }
17121714 }
1713-
1714- // Set denormal and rounding behavior to match the default TF
1715- // ThreadPool behavior.
1716- tsl::port::ScopedFlushDenormal flush;
1717- tsl::port::ScopedSetRound round (FE_TONEAREST);
1718-
1719- for (const auto & buffer_info : buffer_table) {
1720- CHECK (buffer_info.buffer .IsAvailable ());
1721- if (buffer_info.buffer .IsError ()) {
1722- scoped_async_execution.SetError (
1723- Internal (" Error preparing computation: %s" ,
1724- buffer_info.buffer .GetError ().message ()));
1725- returned_future_can_be_set_event.SetStateConcrete ();
1726- return ;
1727- }
1728- }
1729- absl::Status status;
1730- if (cpu_executable->has_thunks ()) {
1731- // Call interpreted thunk sequence implementing XLA executable.
1732- absl::InlinedVector<MaybeOwningDeviceAddress, 8 > buffer_device_mem;
1733- buffer_device_mem.reserve (buffer_table.size ());
1734- for (const auto & buffer_info : buffer_table) {
1735- buffer_device_mem.emplace_back (
1736- se::DeviceAddressBase (buffer_info.buffer ->untyped_data (),
1737- buffer_info.buffer ->size_bytes ()));
1738- }
1739-
1740- cpu::BufferAllocations allocations (buffer_device_mem);
1741-
1742- absl::StatusOr<cpu::Thunk::CollectiveExecuteParams>
1743- collective_params =
1744- cpu::Thunk::CollectiveExecuteParams::Create (&run_options);
1745-
1746- absl::StatusOr<cpu::Thunk::CustomCallExecuteParams>
1747- custom_call_params =
1748- cpu::Thunk::CustomCallExecuteParams::Create (&run_options);
1749-
1750- absl::StatusOr<std::optional<cpu::Thunk::YnnParams>> ynn_params (
1751- std::nullopt );
1752- if (cpu_executable->has_ynn_fusions ()) {
1753- ynn_params = cpu::Thunk::YnnParams::Create (&run_options);
1715+ auto status = [&]() -> absl::Status {
1716+ TF_ASSIGN_OR_RETURN (auto thunks_execute_event, execute_thunks ());
1717+ if (thunks_execute_event.IsError ()) {
1718+ return thunks_execute_event.GetError ();
17541719 }
1755-
1756- cpu::ThreadPoolTaskRunner task_runner (
1757- run_options.intra_op_thread_pool ()->getPool ());
1758-
1759- if (collective_params.ok ()) {
1760- cpu::Thunk::ExecuteParams execute_params = {
1761- cpu_executable->function_library (),
1762- &allocations,
1763- cpu::GetXfeedManager (run_options.device_ordinal ()),
1764- run_options.intra_op_thread_pool (),
1765- &task_runner,
1766- &*collective_params,
1767- &*custom_call_params,
1768- *ynn_params ? &**ynn_params : nullptr ,
1769- run_options.run_id ().ToInt (),
1770- run_options.device_ordinal (),
1771- };
1772-
1773- auto thunks_execute_event =
1774- cpu_executable->thunks ().Execute (execute_params);
1775-
1776- tsl::profiler::TraceMe trace ([&] {
1777- return tsl::profiler::TraceMeEncode (
1778- " ThunkExecutor::Execute (wait for completion)" ,
1779- {{" run_id" , run_options.run_id ().ToInt ()},
1780- {" device_ordinal" , run_options.device_ordinal ()}});
1781- });
1782- tsl::BlockUntilReady (thunks_execute_event);
1783- status = thunks_execute_event.IsError ()
1784- ? thunks_execute_event.GetError ()
1785- : absl::OkStatus ();
1786- } else {
1787- status = collective_params.status ();
1788- }
1789-
1790- } else {
1791- status = Internal (" CpuExecutable has no thunks." );
1792- }
1720+ return absl::OkStatus ();
1721+ }();
17931722
17941723 if (!status.ok ()) {
17951724 // CPU computation fails with an error.
0 commit comments