Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions libc/shared/rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ namespace rpc {
#define __scoped_atomic_thread_fence(ord, scp) __atomic_thread_fence(ord)
#endif

/// Generic codes that can be used whem implementing the server.
enum Status {
SUCCESS = 0x0,
ERROR = 0x1000,
UNHANDLED_OPCODE = 0x1001,
};

/// A fixed size channel used to communicate between the RPC client and server.
struct Buffer {
uint64_t data[8];
Expand Down
209 changes: 88 additions & 121 deletions libc/utils/gpu/loader/Loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "include/llvm-libc-types/rpc_opcodes_t.h"
#include "include/llvm-libc-types/test_rpc_opcodes_t.h"
#include "shared/rpc.h"

#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -103,129 +104,95 @@ inline void handle_error_impl(const char *file, int32_t line, const char *msg) {
fprintf(stderr, "%s:%d:0: Error: %s\n", file, line, msg);
exit(EXIT_FAILURE);
}

inline void handle_error_impl(const char *file, int32_t line,
rpc_status_t err) {
fprintf(stderr, "%s:%d:0: Error: %d\n", file, line, err);
exit(EXIT_FAILURE);
}
#define handle_error(X) handle_error_impl(__FILE__, __LINE__, X)

template <uint32_t lane_size>
inline void register_rpc_callbacks(rpc_device_t device) {
static_assert(lane_size == 32 || lane_size == 64, "Invalid Lane size");
// Register the ping test for the `libc` tests.
rpc_register_callback(
device, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
[](rpc_port_t port, void *data) {
rpc_recv_and_send(
port,
[](rpc_buffer_t *buffer, void *data) {
reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
},
data);
},
nullptr);

// Register the interface test callbacks.
rpc_register_callback(
device, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
[](rpc_port_t port, void *data) {
uint64_t cnt = 0;
bool end_with_recv;
rpc_recv(
port,
[](rpc_buffer_t *buffer, void *data) {
*reinterpret_cast<bool *>(data) = buffer->data[0];
},
&end_with_recv);
rpc_recv(
port,
[](rpc_buffer_t *buffer, void *data) {
*reinterpret_cast<uint64_t *>(data) = buffer->data[0];
},
&cnt);
rpc_send(
port,
[](rpc_buffer_t *buffer, void *data) {
uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
buffer->data[0] = cnt = cnt + 1;
},
&cnt);
rpc_recv(
port,
[](rpc_buffer_t *buffer, void *data) {
*reinterpret_cast<uint64_t *>(data) = buffer->data[0];
},
&cnt);
rpc_send(
port,
[](rpc_buffer_t *buffer, void *data) {
uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
buffer->data[0] = cnt = cnt + 1;
},
&cnt);
rpc_recv(
port,
[](rpc_buffer_t *buffer, void *data) {
*reinterpret_cast<uint64_t *>(data) = buffer->data[0];
},
&cnt);
rpc_recv(
port,
[](rpc_buffer_t *buffer, void *data) {
*reinterpret_cast<uint64_t *>(data) = buffer->data[0];
},
&cnt);
rpc_send(
port,
[](rpc_buffer_t *buffer, void *data) {
uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
buffer->data[0] = cnt = cnt + 1;
},
&cnt);
rpc_send(
port,
[](rpc_buffer_t *buffer, void *data) {
uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
buffer->data[0] = cnt = cnt + 1;
},
&cnt);
if (end_with_recv)
rpc_recv(
port,
[](rpc_buffer_t *buffer, void *data) {
*reinterpret_cast<uint64_t *>(data) = buffer->data[0];
},
&cnt);
else
rpc_send(
port,
[](rpc_buffer_t *buffer, void *data) {
uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
buffer->data[0] = cnt = cnt + 1;
},
&cnt);
},
nullptr);

// Register the stream test handler.
rpc_register_callback(
device, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
[](rpc_port_t port, void *data) {
uint64_t sizes[lane_size] = {0};
void *dst[lane_size] = {nullptr};
rpc_recv_n(
port, dst, sizes,
[](uint64_t size, void *) -> void * { return new char[size]; },
nullptr);
rpc_send_n(port, dst, sizes);
for (uint64_t i = 0; i < lane_size; ++i) {
if (dst[i])
delete[] reinterpret_cast<uint8_t *>(dst[i]);
}
},
nullptr);
template <uint32_t num_lanes, typename Alloc, typename Free>
inline uint32_t handle_server(rpc::Server &server, uint32_t index,
Alloc &&alloc, Free &&free) {
auto port = server.try_open(num_lanes, index);
if (!port)
return 0;
index = port->get_index() + 1;

int status = rpc::SUCCESS;
switch (port->get_opcode()) {
case RPC_TEST_INCREMENT: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
});
break;
}
case RPC_TEST_INTERFACE: {
bool end_with_recv;
uint64_t cnt;
port->recv([&](rpc::Buffer *buffer, uint32_t) {
end_with_recv = buffer->data[0];
});
port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
port->send([&](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = cnt = cnt + 1;
});
port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
port->send([&](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = cnt = cnt + 1;
});
port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
port->send([&](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = cnt = cnt + 1;
});
port->send([&](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = cnt = cnt + 1;
});
if (end_with_recv)
port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
else
port->send([&](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = cnt = cnt + 1;
});

break;
}
case RPC_TEST_STREAM: {
uint64_t sizes[num_lanes] = {0};
void *dst[num_lanes] = {nullptr};
port->recv_n(dst, sizes,
[](uint64_t size) -> void * { return new char[size]; });
port->send_n(dst, sizes);
for (uint64_t i = 0; i < num_lanes; ++i) {
if (dst[i])
delete[] reinterpret_cast<uint8_t *>(dst[i]);
}
break;
}
case RPC_TEST_NOOP: {
port->recv([&](rpc::Buffer *, uint32_t) {});
break;
}
case RPC_MALLOC: {
port->recv_and_send([&](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = reinterpret_cast<uintptr_t>(alloc(buffer->data[0]));
});
break;
}
case RPC_FREE: {
port->recv([&](rpc::Buffer *buffer, uint32_t) {
free(reinterpret_cast<void *>(buffer->data[0]));
});
break;
}
default:
status = libc_handle_rpc_port(&*port, num_lanes);
break;
}

// Handle all of the `libc` specific opcodes.
if (status != rpc::SUCCESS)
handle_error("Error handling RPC server");

port->close();

return index;
}

#endif
Loading
Loading