Skip to content

Commit 69512e6

Browse files
committed
chore: logs
1 parent 096d821 commit 69512e6

File tree

5 files changed

+81
-18
lines changed

5 files changed

+81
-18
lines changed

codegen/gen_client.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@ extern conn_t *rpc_client_get_connection(unsigned int index);
1919
int is_unified_pointer(conn_t *conn, void *arg);
2020
int maybe_copy_unified_arg(conn_t *conn, void *arg, enum cudaMemcpyKind kind);
2121
extern void rpc_close(conn_t *conn);
22+
extern void increment_host_nodes();
23+
extern void wait_for_callbacks();
2224
void invoke_host_funcs(const int index, void *udata);
2325

2426
nvmlReturn_t nvmlInit_v2() {
@@ -23092,6 +23094,7 @@ cudaError_t cudaGraphAddHostNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph,
2309223094
size_t numDependencies,
2309323095
const struct cudaHostNodeParams *pNodeParams) {
2309423096
conn_t *conn = rpc_client_get_connection(0);
23097+
increment_host_nodes();
2309523098
printf("hmmmm %p\n", pNodeParams->fn);
2309623099
if (maybe_copy_unified_arg(conn, (void *)&numDependencies,
2309723100
cudaMemcpyHostToDevice) < 0)
@@ -24972,6 +24975,8 @@ cudaError_t cudaGraphLaunch(cudaGraphExec_t graphExec, cudaStream_t stream) {
2497224975
rpc_read_end(conn) < 0)
2497324976
return cudaErrorDevicesUnavailable;
2497424977

24978+
wait_for_callbacks();
24979+
2497524980
if (maybe_copy_unified_arg(conn, (void *)&graphExec, cudaMemcpyDeviceToHost) <
2497624981
0)
2497724982
return cudaErrorDevicesUnavailable;

codegen/gen_server.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21194,8 +21194,8 @@ int handle_cudaGraphAddHostNode(conn_t *conn) {
2119421194

2119521195
hostFnData = (callBackData_t *)malloc(sizeof(callBackData_t));
2119621196
// assign the previous function pointer so we can map back to it
21197+
store_conn(conn);
2119721198
hostFnData->callback = pNodeParams.fn;
21198-
hostFnData->conn = conn;
2119921199
hostFnData->data = pNodeParams.userData;
2120021200

2120121201
pNodeParams.fn = invoke_host_func;
@@ -22714,11 +22714,15 @@ int handle_cudaGraphLaunch(conn_t *conn) {
2271422714

2271522715
scuda_intercept_result = cudaGraphLaunch(graphExec, stream);
2271622716

22717+
std::cout << "RESPONDING TO CUDAGRAPH" << std::endl;
22718+
2271722719
if (rpc_write_start_response(conn, request_id) < 0 ||
2271822720
rpc_write(conn, &scuda_intercept_result, sizeof(cudaError_t)) < 0 ||
2271922721
rpc_write_end(conn) < 0)
2272022722
goto ERROR_0;
2272122723

22724+
std::cout << "DONE CUDAGRAPH" << std::endl;
22725+
2272222726
return 0;
2272322727
ERROR_0:
2272422728
return -1;

rpc.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <sys/socket.h>
22

33
#include "rpc.h"
4+
#include <string.h>
45
#include <iostream>
56
#include <unistd.h>
67

@@ -10,17 +11,19 @@
1011
//
1112
// it is not necessary to call rpc_read_start() if it is the first call in
1213
// the sequence because by convention, the handler owns the read lock on entry.
13-
unsigned int rpc_dispatch(conn_t *conn, int parity) {
14+
int rpc_dispatch(conn_t *conn, int parity) {
1415
if (pthread_mutex_lock(&conn->read_mutex) < 0) {
1516
return -1;
1617
}
1718

18-
while (true) {
19+
while (1) {
1920
std::cout << "rpc_dispatch: waiting for read id" << std::endl;
2021

2122
while (conn->read_id != 0)
2223
pthread_cond_wait(&conn->read_cond, &conn->read_mutex);
2324

25+
std::cout << "reading req id... " << std::endl;
26+
2427
// the read id is zero so it's our turn to read the next int which is the
2528
// request id of the next request.
2629
if (rpc_read(conn, &conn->read_id, sizeof(int)) < 0) {
@@ -33,13 +36,12 @@ unsigned int rpc_dispatch(conn_t *conn, int parity) {
3336
if (conn->read_id % 2 == parity) {
3437
std::cout << "rpc_dispatch: new read id: " << conn->read_id << std::endl;
3538
// this request is the one to be dispatched, read the op and return it
36-
unsigned int op;
37-
if (rpc_read(conn, &op, sizeof(unsigned int)) < 0) {
39+
int op;
40+
if (rpc_read(conn, &op, sizeof(int)) < 0) {
3841
pthread_mutex_unlock(&conn->read_mutex);
3942
return -1;
4043
}
4144

42-
pthread_mutex_unlock(&conn->read_mutex);
4345
return op;
4446
} else {
4547
// this is a response to a request, so signal the update and wait for
@@ -62,6 +64,7 @@ int rpc_read_start(conn_t *conn, int write_id) {
6264
if (pthread_mutex_lock(&conn->read_mutex) < 0)
6365
return -1;
6466

67+
6568
// wait for the active read id to be the request id we are waiting for
6669
while (conn->read_id != write_id)
6770
if (pthread_cond_wait(&conn->read_cond, &conn->read_mutex) < 0)
@@ -71,7 +74,11 @@ int rpc_read_start(conn_t *conn, int write_id) {
7174
}
7275

7376
int rpc_read(conn_t *conn, void *data, size_t size) {
74-
return recv(conn->connfd, data, size, MSG_WAITALL);
77+
int bytes_read = recv(conn->connfd, data, size, MSG_WAITALL);
78+
if (bytes_read == -1) {
79+
printf("recv error: %s\n", strerror(errno));
80+
}
81+
return bytes_read;
7582
}
7683

7784
// rpc_read_end releases the response lock on the given connection.
@@ -99,7 +106,7 @@ int rpc_wait_for_response(conn_t *conn) {
99106
//
100107
// only one request can be active at a time, so this function will take the
101108
// request lock from the connection.
102-
int rpc_write_start_request(conn_t *conn, const unsigned int op) {
109+
int rpc_write_start_request(conn_t *conn, const int op) {
103110
if (pthread_mutex_lock(&conn->write_mutex) < 0) {
104111
#ifdef VERBOSE
105112
std::cout << "rpc_write_start failed due to rpc_open() < 0 || "

rpc.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ typedef struct {
1919
int write_iov_count = 0;
2020
} conn_t;
2121

22-
extern unsigned int rpc_dispatch(conn_t *conn, int parity);
22+
extern int rpc_dispatch(conn_t *conn, int parity);
2323
extern int rpc_read_start(conn_t *conn, int write_id);
2424
extern int rpc_read(conn_t *conn, void *data, size_t size);
2525
extern int rpc_read_end(conn_t *conn);
2626

2727
extern int rpc_wait_for_response(conn_t *conn);
2828

29-
extern int rpc_write_start_request(conn_t *conn, const unsigned int op);
29+
extern int rpc_write_start_request(conn_t *conn, const int op);
3030
extern int rpc_write_start_response(conn_t *conn, const int read_id);
3131
extern int rpc_write(conn_t *conn, const void *data, const size_t size);
3232
extern int rpc_write_end(conn_t *conn);

server.cpp

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ static void segfault(int sig, siginfo_t *info, void *unused) {
114114
raise(SIGSEGV);
115115
}
116116

117+
conn_t* stored_conn;
118+
119+
void store_conn(const void *conn) {
120+
stored_conn = (conn_t *)conn;
121+
}
122+
117123
typedef struct callBackData {
118124
conn_t *conn;
119125
void (*callback)(void *);
@@ -122,15 +128,53 @@ typedef struct callBackData {
122128

123129
void invoke_host_func(void *data) {
124130
callBackData_t *tmp = (callBackData_t *)(data);
125-
void *res;
131+
void* scuda_intercept_result;
132+
133+
// Validate connection
134+
if (!stored_conn) {
135+
std::cerr << "Error: Connection is NULL in invoke_host_func" << std::endl;
136+
return;
137+
}
138+
139+
printf("Invoking host function %p\n", tmp->callback);
140+
141+
if (rpc_write_start_request(stored_conn, 1) < 0) {
142+
std::cerr << "Error: rpc_write_start_request failed" << std::endl;
143+
return;
144+
}
145+
if (rpc_write(stored_conn, &tmp->callback, sizeof(void *)) < 0) {
146+
std::cerr << "Error: rpc_write failed on callback" << std::endl;
147+
return;
148+
}
149+
if (rpc_write(stored_conn, &tmp->data, sizeof(void *)) < 0) {
150+
std::cerr << "Error: rpc_write failed on data" << std::endl;
151+
return;
152+
}
126153

127-
printf("invoking host function %p\n", tmp->callback);
154+
// Ensure request is fully sent before waiting for response
155+
if (rpc_write_end(stored_conn) < 0) {
156+
std::cerr << "Error: rpc_write_end failed" << std::endl;
157+
return;
158+
}
159+
160+
printf("hereeee %p\n", tmp->callback);
161+
162+
if (rpc_wait_for_response(stored_conn) < 0) {
163+
std::cerr << "Error: rpc_wait_for_response failed" << std::endl;
164+
return;
165+
}
166+
167+
if (rpc_read(stored_conn, &scuda_intercept_result, sizeof(void*)) < 0) {
168+
std::cerr << "Error: rpc_read failed on scuda_intercept_result" << std::endl;
169+
return;
170+
}
171+
172+
if (rpc_read_end(stored_conn) < 0) {
173+
std::cerr << "Error: rpc_read_end failed" << std::endl;
174+
return;
175+
}
128176

129-
if (rpc_write_start_request(tmp->conn, 1) < 0 ||
130-
rpc_write(tmp->conn, &tmp->callback, sizeof(void *)) < 0 ||
131-
rpc_write(tmp->conn, &tmp->data, sizeof(void *)) < 0 ||
132-
rpc_wait_for_response(tmp->conn) < 0 || rpc_read_end(tmp->conn) < 0)
133-
std::cout << "failed to write memory: " << &faulting_address << std::endl;
177+
std::cout << "RESULT IS: " << scuda_intercept_result << std::endl;
134178
}
135179

136180
void append_host_func_ptr(const void *conn, void *ptr) {
@@ -172,7 +216,9 @@ void client_handler(int connfd) {
172216
printf("Client connected.\n");
173217

174218
while (1) {
175-
unsigned int op = rpc_dispatch(&conn, 0);
219+
int op = rpc_dispatch(&conn, 0);
220+
221+
std::cout << "GOT OP: " << op << std::endl;
176222

177223
auto opHandler = get_handler(op);
178224
if (opHandler(&conn) < 0) {
@@ -183,6 +229,7 @@ void client_handler(int connfd) {
183229
if (pthread_mutex_destroy(&conn.read_mutex) < 0 ||
184230
pthread_mutex_destroy(&conn.write_mutex) < 0)
185231
std::cerr << "Error destroying mutex." << std::endl;
232+
186233
close(connfd);
187234
}
188235

0 commit comments

Comments
 (0)