Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions xprof/pywrap/_pywrap_profiler_plugin.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@
# ==============================================================================

def monitor(arg0: str, arg1: int, arg2: int, arg3: bool) -> str: ...
def trace(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: int, arg5: int, arg6: dict) -> None: ...
def trace(
arg0: str,
arg1: str,
arg2: str,
arg3: bool,
arg4: int,
arg5: int,
arg6: dict,
) -> None: ...
def start_continuous_profiling(service_addr: str, options: dict) -> None: ...
def get_snapshot(service_addr: str, logdir: str) -> None: ...

def xspace_to_tools_data(arg0: list, arg1: str, arg2: dict = ...) -> tuple: ...
def xspace_to_tools_data_from_byte_string(arg0: list, arg1: list, arg2: str, arg3: dict) -> tuple: ...
def xspace_to_tools_data_from_byte_string(
arg0: list, arg1: list, arg2: str, arg3: dict
) -> tuple: ...

def start_grpc_server(port: int) -> None: ...
def initialize_stubs(worker_service_addresses: str) -> None: ...

23 changes: 23 additions & 0 deletions xprof/pywrap/profiler_plugin_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -116,6 +117,28 @@ absl::Status Monitor(const char* service_addr, int duration_ms,
return absl::OkStatus();
}

absl::Status StartContinuousProfiling(const char* service_addr,
const ToolOptions& tool_options) {
LOG(INFO) << "StartContinuousProfiling";
TF_RETURN_IF_ERROR(tsl::profiler::ValidateHostPortPair(service_addr));
tensorflow::RemoteProfilerSessionManagerOptions options;
bool is_cloud_tpu_session;
// Even though the duration is set to 2 seconds, the profiling will continue
// until GetSnapshot is called, it is only done since
// GetRemoteSessionManagerOptionsLocked requires a duration.
const int32_t kContinuousProfilingdurationMs = 2000;
options = tsl::profiler::GetRemoteSessionManagerOptionsLocked(
service_addr, "", "", false, kContinuousProfilingdurationMs, tool_options,
&is_cloud_tpu_session);
return tsl::profiler::StartContinuousProfiling(service_addr, options);
}

absl::Status GetSnapshot(const char* service_addr, const char* logdir) {
LOG(INFO) << "GetSnapshot";
TF_RETURN_IF_ERROR(tsl::profiler::ValidateHostPortPair(service_addr));
return tsl::profiler::GetSnapshot(service_addr, logdir);
}

static absl::once_flag server_init_flag;

void StartGrpcServer(int port) {
Expand Down
6 changes: 6 additions & 0 deletions xprof/pywrap/profiler_plugin_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ absl::Status Monitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
tsl::string* result);

absl::Status StartContinuousProfiling(
const char* service_addr,
const tensorflow::profiler::ToolOptions& tool_options);

absl::Status GetSnapshot(const char* service_addr, const char* logdir);

absl::StatusOr<std::pair<std::string, bool>> XSpaceToToolsData(
std::vector<std::string> xspace_paths, const std::string& tool_name,
const tensorflow::profiler::ToolOptions& tool_options);
Expand Down
21 changes: 21 additions & 0 deletions xprof/pywrap/pywrap_profiler_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,27 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) {
return content;
});

m.def("start_continuous_profiling",
[](const char* service_addr, py::dict options) {
absl::Status status;
ToolOptions tool_options = ToolOptionsFromPythonDict(options);
{
py::gil_scoped_release release;
status = xprof::pywrap::StartContinuousProfiling(service_addr,
tool_options);
}
xla::ThrowIfError(status);
});

m.def("get_snapshot", [](const char* service_addr, const char* logdir) {
absl::Status status;
{
py::gil_scoped_release release;
status = xprof::pywrap::GetSnapshot(service_addr, logdir);
}
xla::ThrowIfError(status);
});

m.def(
"xspace_to_tools_data",
[](const py::list& xspace_path_list, const py::str& py_tool_name,
Expand Down