Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 69 additions & 140 deletions xla/pjrt/cpu/cpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
run_options.set_device_assignment(device_assignment.get());
run_options.set_intra_op_thread_pool(client_->eigen_intraop_device());

auto cpu_run_options = std::make_shared<cpu::CpuExecutableRunOptions>();
auto cpu_run_options = std::make_unique<cpu::CpuExecutableRunOptions>();
run_options.set_cpu_executable_run_options(cpu_run_options.get());

const CpuExecuteContext* cpu_execute_context =
Expand Down Expand Up @@ -1567,83 +1567,87 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
execute_inline = true;
}

absl::Status inline_result_status;
if (input_deps.events().empty() && execute_inline) {
// Synchronously call generated function or thunk sequence.

auto execute_thunks = [cpu_executable, buffer_table = std::move(buffer_table),
eigen_device = client()->eigen_intraop_device(),
run_options = std::move(run_options)]()
-> absl::StatusOr<tsl::AsyncValueRef<cpu::Thunk::ExecuteEvent>> {
// Set denormal and rounding behavior to match the default TF
// ThreadPool behavior.
tsl::port::ScopedFlushDenormal flush;
tsl::port::ScopedSetRound round(FE_TONEAREST);

// Execution status for XLA:CPU thunk runtime.
tsl::AsyncValueRef<cpu::Thunk::ExecuteEvent> thunks_execute_event;

// Immediately allocate memory and prepare for computation.
buffer_alloc.Allocate(*client()->allocator());
buffer_alloc_and_copy.AllocateAndCopy(*client()->allocator());
for (const auto& buffer_info : buffer_table) {
CHECK(buffer_info.buffer.IsAvailable());
if (buffer_info.buffer.IsError()) {
return buffer_info.buffer.GetError();
}
}

if (cpu_executable->has_thunks()) {
// Call interpreted thunk sequence implementing XLA executable.
absl::InlinedVector<MaybeOwningDeviceAddress, 8> buffer_device_mem;
buffer_device_mem.reserve(buffer_table.size());
for (const auto& buffer_info : buffer_table) {
buffer_device_mem.emplace_back(
se::DeviceAddressBase(buffer_info.buffer->untyped_data(),
buffer_info.buffer->size_bytes()));
}

cpu::BufferAllocations allocations(buffer_device_mem);
if (!cpu_executable->has_thunks()) {
return Internal("CpuExecutable has no thunks.");
}
// Call interpreted thunk sequence implementing XLA executable.
absl::InlinedVector<MaybeOwningDeviceAddress, 8> buffer_device_mem;
buffer_device_mem.reserve(buffer_table.size());
for (const auto& buffer_info : buffer_table) {
buffer_device_mem.emplace_back(
se::DeviceAddressBase(buffer_info.buffer->untyped_data(),
buffer_info.buffer->size_bytes()));
}

TF_ASSIGN_OR_RETURN(
cpu::Thunk::CollectiveExecuteParams collective_params,
cpu::Thunk::CollectiveExecuteParams::Create(&run_options));
cpu::BufferAllocations allocations(buffer_device_mem);

TF_ASSIGN_OR_RETURN(
cpu::Thunk::CustomCallExecuteParams custom_call_execute_params,
cpu::Thunk::CustomCallExecuteParams::Create(&run_options));
TF_ASSIGN_OR_RETURN(
cpu::Thunk::CollectiveExecuteParams collective_params,
cpu::Thunk::CollectiveExecuteParams::Create(&run_options));

std::optional<cpu::Thunk::YnnParams> ynn_params;
if (cpu_executable->has_ynn_fusions()) {
TF_ASSIGN_OR_RETURN(ynn_params,
cpu::Thunk::YnnParams::Create(&run_options));
}
TF_ASSIGN_OR_RETURN(
cpu::Thunk::CustomCallExecuteParams custom_call_execute_params,
cpu::Thunk::CustomCallExecuteParams::Create(&run_options));

cpu::ThreadPoolTaskRunner task_runner(
run_options.intra_op_thread_pool()->getPool());

cpu::Thunk::ExecuteParams execute_params = {
cpu_executable->function_library(),
&allocations,
cpu::GetXfeedManager(run_options.device_ordinal()),
run_options.intra_op_thread_pool(),
&task_runner,
&collective_params,
&custom_call_execute_params,
ynn_params ? &*ynn_params : nullptr,
run_options.run_id().ToInt(),
run_options.device_ordinal(),
};
std::optional<cpu::Thunk::YnnParams> ynn_params;
if (cpu_executable->has_ynn_fusions()) {
TF_ASSIGN_OR_RETURN(ynn_params,
cpu::Thunk::YnnParams::Create(&run_options));
}

thunks_execute_event = cpu_executable->thunks().Execute(execute_params);
cpu::ThreadPoolTaskRunner task_runner(
run_options.intra_op_thread_pool()->getPool());

cpu::Thunk::ExecuteParams execute_params = {
cpu_executable->function_library(),
&allocations,
cpu::GetXfeedManager(run_options.device_ordinal()),
run_options.intra_op_thread_pool(),
&task_runner,
&collective_params,
&custom_call_execute_params,
ynn_params ? &*ynn_params : nullptr,
run_options.run_id().ToInt(),
run_options.device_ordinal(),
};

auto thunks_execute_event =
cpu_executable->thunks().Execute(execute_params);

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"ThunkExecutor::Execute (wait for completion)",
{{"run_id", run_options.run_id().ToInt()},
{"device_ordinal", run_options.device_ordinal()}});
});
tsl::BlockUntilReady(thunks_execute_event);
return thunks_execute_event;
};

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"ThunkExecutor::Execute (wait for completion)",
{{"run_id", run_options.run_id().ToInt()},
{"device_ordinal", run_options.device_ordinal()}});
});
tsl::BlockUntilReady(thunks_execute_event);
absl::Status inline_result_status;
if (input_deps.events().empty() && execute_inline) {
// Synchronously call generated function or thunk sequence.
buffer_alloc.Allocate(*client()->allocator());
buffer_alloc_and_copy.AllocateAndCopy(*client()->allocator());

} else {
return Internal("CpuExecutable has no thunks.");
}
TF_ASSIGN_OR_RETURN(auto thunks_execute_event, execute_thunks());

if (thunks_execute_event.IsError()) {
inline_result_status = thunks_execute_event.GetError();
Expand Down Expand Up @@ -1683,16 +1687,14 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
events_avs_ref,
[cpu_executable, buffer_alloc = std::move(buffer_alloc),
buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
buffer_table = std::move(buffer_table),
run_options = std::move(run_options),
execute_thunks = std::move(execute_thunks),
device_assignment = std::move(device_assignment),
cpu_run_options = std::move(cpu_run_options),
compute_reservation = std::move(compute_reservation),
tuple_index_table = std::move(tuple_index_table),
scoped_async_execution = std::move(scoped_async_execution),
input_deps_avs = std::move(input_deps).Consume(),
allocator = client()->allocator(),
eigen_device = client()->eigen_intraop_device(),
returned_future_can_be_set_event =
returned_future_can_be_set_event.CopyRef()]() mutable {
// Because `input_deps` contains the definition events of all inputs,
Expand All @@ -1710,86 +1712,13 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
return;
}
}

// Set denormal and rounding behavior to match the default TF
// ThreadPool behavior.
tsl::port::ScopedFlushDenormal flush;
tsl::port::ScopedSetRound round(FE_TONEAREST);

for (const auto& buffer_info : buffer_table) {
CHECK(buffer_info.buffer.IsAvailable());
if (buffer_info.buffer.IsError()) {
scoped_async_execution.SetError(
Internal("Error preparing computation: %s",
buffer_info.buffer.GetError().message()));
returned_future_can_be_set_event.SetStateConcrete();
return;
}
}
absl::Status status;
if (cpu_executable->has_thunks()) {
// Call interpreted thunk sequence implementing XLA executable.
absl::InlinedVector<MaybeOwningDeviceAddress, 8> buffer_device_mem;
buffer_device_mem.reserve(buffer_table.size());
for (const auto& buffer_info : buffer_table) {
buffer_device_mem.emplace_back(
se::DeviceAddressBase(buffer_info.buffer->untyped_data(),
buffer_info.buffer->size_bytes()));
}

cpu::BufferAllocations allocations(buffer_device_mem);

absl::StatusOr<cpu::Thunk::CollectiveExecuteParams>
collective_params =
cpu::Thunk::CollectiveExecuteParams::Create(&run_options);

absl::StatusOr<cpu::Thunk::CustomCallExecuteParams>
custom_call_params =
cpu::Thunk::CustomCallExecuteParams::Create(&run_options);

absl::StatusOr<std::optional<cpu::Thunk::YnnParams>> ynn_params(
std::nullopt);
if (cpu_executable->has_ynn_fusions()) {
ynn_params = cpu::Thunk::YnnParams::Create(&run_options);
auto status = [&]() -> absl::Status {
TF_ASSIGN_OR_RETURN(auto thunks_execute_event, execute_thunks());
if (thunks_execute_event.IsError()) {
return thunks_execute_event.GetError();
}

cpu::ThreadPoolTaskRunner task_runner(
run_options.intra_op_thread_pool()->getPool());

if (collective_params.ok()) {
cpu::Thunk::ExecuteParams execute_params = {
cpu_executable->function_library(),
&allocations,
cpu::GetXfeedManager(run_options.device_ordinal()),
run_options.intra_op_thread_pool(),
&task_runner,
&*collective_params,
&*custom_call_params,
*ynn_params ? &**ynn_params : nullptr,
run_options.run_id().ToInt(),
run_options.device_ordinal(),
};

auto thunks_execute_event =
cpu_executable->thunks().Execute(execute_params);

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"ThunkExecutor::Execute (wait for completion)",
{{"run_id", run_options.run_id().ToInt()},
{"device_ordinal", run_options.device_ordinal()}});
});
tsl::BlockUntilReady(thunks_execute_event);
status = thunks_execute_event.IsError()
? thunks_execute_event.GetError()
: absl::OkStatus();
} else {
status = collective_params.status();
}

} else {
status = Internal("CpuExecutable has no thunks.");
}
return absl::OkStatus();
}();

if (!status.ok()) {
// CPU computation fails with an error.
Expand Down
Loading