Skip to content

Commit 77032df

Browse files
committed
Added async support. callback prototype should probably change
1 parent 6d56345 commit 77032df

File tree

3 files changed

+125
-21
lines changed

3 files changed

+125
-21
lines changed

context.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ context_address(getdns_ContextObject *self, PyObject *args, PyObject *keywds)
989989
PyDictObject *extensions_obj = 0;
990990
void *userarg;
991991
long tid;
992-
char * callback = 0;
992+
char *callback = 0;
993993
PyObject *resp;
994994

995995
if ((context = PyCapsule_GetPointer(self->py_context, "context")) == NULL) {

getdns.c

Lines changed: 108 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include <getdns/getdns.h>
4444
#include <getdns/getdns_ext_libevent.h>
4545
#include <event2/event.h>
46+
#include <pthread.h>
4647
#include "pygetdns.h"
4748

4849

@@ -156,31 +157,30 @@ void
156157
callback_shim(getdns_context *context, getdns_callback_type_t type, getdns_dict *resp,
157158
void *u, getdns_transaction_t tid)
158159
{
159-
PyObject *main_module;
160-
PyObject *main_dict;
161-
PyObject *getdns_runner;
162160
pygetdns_libevent_callback_data *callback_data;
163161
PyObject *response;
162+
PyObject *getdns_runner;
163+
PyGILState_STATE state;
164164

165-
if ((main_module = PyImport_AddModule("__main__")) == 0) {
166-
PyErr_SetString(getdns_error, "No __main__");
167-
/* need to throw an error here */
168-
}
169165
callback_data = (pygetdns_libevent_callback_data *)u;
170-
main_dict = PyModule_GetDict(main_module);
171-
if ((getdns_runner = PyDict_GetItemString(main_dict, callback_data->callback_func)) == 0) {
172-
PyErr_SetString(getdns_error, "callback not found");
173-
/* need to throw an error here XXX */
166+
getdns_runner = callback_data->callback_func;
167+
if (!PyCallable_Check(getdns_runner)) { /* XXX */
168+
printf("callback not runnable\n");
169+
return;
174170
}
175-
/* Python callback prototype: */
176-
/* callback(context, callback_type, response, userarg, tid) */
177-
178171
if ((response = getFullResponse(resp)) == 0) {
179172
PyErr_SetString(getdns_error, "Unable to decode response");
173+
return;
180174
/* need to throw exceptiion XXX */
181175
}
176+
177+
/* Python callback prototype: */
178+
/* callback(context, callback_type, response, userarg, tid) */
179+
state = PyGILState_Ensure();
182180
PyObject_CallFunction(getdns_runner, "OHOsl", context, type, response,
183181
callback_data->userarg, tid);
182+
PyObject_CallFunction(getdns_runner, "s", "asdfasdf");
183+
PyGILState_Release(state);
184184
}
185185

186186

@@ -229,8 +229,29 @@ context_create(PyObject *self, PyObject *args, PyObject *keywds)
229229
}
230230

231231

232+
/*
233+
* called from pthread_create. Pull out the query arguments,
234+
* get the Python callback function from the dictionary for
235+
* __main__
236+
*/
237+
238+
void
239+
marshall_query(pygetdns_async_args_blob *blob)
240+
{
241+
PyObject *ret;
242+
243+
if ((ret = dispatch_query(blob->context_capsule, blob->name,
244+
blob->type, blob->extensions, blob->userarg, blob->tid,
245+
blob->callback)) == 0) {
246+
PyErr_SetString(getdns_error, GETDNS_RETURN_GENERIC_ERROR_TEXT);
247+
pthread_exit(0);
248+
}
249+
}
250+
251+
252+
232253
PyObject *
233-
do_query(PyObject *context_capsule,
254+
dispatch_query(PyObject *context_capsule,
234255
void *name,
235256
uint16_t request_type,
236257
PyDictObject *extensions_obj,
@@ -319,22 +340,21 @@ do_query(PyObject *context_capsule,
319340
if (callback) {
320341
struct event_base *gen_event_base;
321342
int dispatch_ret;
322-
pygetdns_libevent_callback_data callback_data;
343+
pygetdns_libevent_callback_data *callback_data = userarg;
323344

324345
if ((gen_event_base = event_base_new()) == NULL) {
325346
PyErr_SetString(getdns_error, GETDNS_RETURN_GENERIC_ERROR_TEXT);
326347
return NULL;
327348
}
328349

329-
callback_data.callback_func = callback;
330-
callback_data.userarg = userarg;
350+
callback_data->userarg = userarg;
331351
if ((ret = getdns_extension_set_libevent_base(context, gen_event_base)) != GETDNS_RETURN_GOOD) {
332352
PyErr_SetString(getdns_error, GETDNS_RETURN_GENERIC_ERROR_TEXT);
333353
return NULL;
334354
}
335355

336356
if ((ret = getdns_general(context, query_name, request_type,
337-
extensions_dict, (void *)&callback_data,
357+
extensions_dict, (void *)callback_data,
338358
(getdns_transaction_t *)&tid, callback_shim)) != GETDNS_RETURN_GOOD) {
339359
PyErr_SetString(getdns_error, GETDNS_RETURN_GENERIC_ERROR_TEXT);
340360
event_base_free(gen_event_base);
@@ -356,6 +376,73 @@ do_query(PyObject *context_capsule,
356376
}
357377

358378

379+
/*
380+
* there's not many people doing this so it probably
381+
* bears some explanation. If there's a callback argument
382+
* we need to spin off a thread to handle the callback,
383+
* and do it in a way that doesn't make the Python thread
384+
* scheduler barf. Additionally, in order to avoid making
385+
* getdns barf, we need to move data off the stack and onto
386+
* the heap for it to be available to the new thread. This includes
387+
* a pointer to the PyObject representing the user-defined
388+
* callback function.
389+
* So basically we're encapsulating the data so that the
390+
* new thread can use it to recreate the calling environment
391+
*/
392+
393+
394+
PyObject *
395+
do_query(PyObject *context_capsule,
396+
void *name,
397+
uint16_t request_type,
398+
PyDictObject *extensions_obj,
399+
void *userarg,
400+
long tid,
401+
char *callback)
402+
403+
{
404+
if (!callback) {
405+
return dispatch_query(context_capsule, name, request_type, extensions_obj,
406+
userarg, tid, callback);
407+
} else {
408+
PyObject *main_module;
409+
PyObject *main_dict;
410+
PyObject *getdns_runner;
411+
pthread_t runner_thread;
412+
pygetdns_async_args_blob *async_blob;
413+
414+
if ((main_module = PyImport_AddModule("__main__")) == 0) {
415+
PyErr_SetString(getdns_error, "No __main__");
416+
/* need to throw an error here */
417+
}
418+
main_dict = PyModule_GetDict(main_module);
419+
if ((getdns_runner = PyDict_GetItemString(main_dict, callback)) == 0) {
420+
PyErr_SetString(getdns_error, "callback not found");
421+
return NULL;
422+
}
423+
424+
async_blob = (pygetdns_async_args_blob *)malloc(sizeof(pygetdns_async_args_blob));
425+
async_blob->context_capsule = context_capsule;
426+
async_blob->name = malloc(256); /* XXX magic number */
427+
strncpy(async_blob->name, name, strlen(name));
428+
async_blob->type = request_type;
429+
async_blob->extensions = extensions_obj;
430+
async_blob->userarg = userarg;
431+
async_blob->tid = tid;
432+
async_blob->callback = malloc(256); /* XXX magic number */
433+
strncpy(async_blob->callback, callback, strlen(callback));
434+
async_blob->runner = getdns_runner;
435+
async_blob->userarg->callback_func = getdns_runner;
436+
437+
Py_BEGIN_ALLOW_THREADS;
438+
pthread_create(&runner_thread, NULL, (void *)marshall_query, (void *)async_blob);
439+
pthread_detach(runner_thread);
440+
Py_END_ALLOW_THREADS;
441+
return Py_None;
442+
}
443+
}
444+
445+
359446
static PyObject *
360447
cancel_callback(PyObject *self, PyObject *args, PyObject *keywds)
361448
{
@@ -1456,6 +1543,8 @@ initgetdns(void)
14561543
{
14571544
PyObject *g;
14581545

1546+
Py_Initialize();
1547+
PyEval_InitThreads();
14591548
if ((g = Py_InitModule("getdns", getdns_methods)) == NULL)
14601549
return;
14611550
getdns_error = PyErr_NewException("getdns.error", NULL, NULL);

pygetdns.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
PyObject *getdns_error;
4141

4242
typedef struct pygetdns_libevent_callback_data {
43-
char *callback_func;
4443
void *userarg;
44+
PyObject *callback_func;
4545
} pygetdns_libevent_callback_data;
4646

4747

@@ -68,6 +68,17 @@ typedef struct {
6868
char *version_string;
6969
} getdns_ContextObject;
7070

71+
typedef struct pygetdns_async_args_blob {
72+
PyObject *context_capsule;
73+
PyObject *runner;
74+
char *name;
75+
uint16_t type;
76+
PyDictObject *extensions;
77+
pygetdns_libevent_callback_data *userarg;
78+
getdns_transaction_t tid;
79+
char *callback;
80+
} pygetdns_async_args_blob;
81+
7182

7283
int context_init(getdns_ContextObject *self, PyObject *args, PyObject *keywds);
7384
PyObject *context_getattro(PyObject *self, PyObject *nameobj);
@@ -111,3 +122,7 @@ PyObject *context_get_num_pending_requests(PyObject *self, PyObject *args, PyObj
111122
PyObject *context_process_async(PyObject *self, PyObject *args, PyObject *keywds);
112123
getdns_dict *getdnsify_addressdict(PyObject *pydict);
113124
void context_dealloc(getdns_ContextObject *self);
125+
void marshall_query(pygetdns_async_args_blob *blog);
126+
PyObject *dispatch_query(PyObject *context_capsule, void *name, uint16_t request_type,
127+
PyDictObject *extensions_obj, void *userarg, long tid, char *callback);
128+

0 commit comments

Comments
 (0)