Skip to content

Commit b7f8dce

Browse files
authored
Add callback for oauth (@stevenylai, #960)
1 parent 8060bd0 commit b7f8dce

File tree

7 files changed

+244
-0
lines changed

7 files changed

+244
-0
lines changed

docs/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,14 @@ The Python bindings also provide some additional configuration properties:
459459
This callback is served upon calling ``client.poll()`` or ``producer.flush()``. See
460460
https://github.com/edenhill/librdkafka/wiki/Statistics" for more information.
461461

462+
* ``oauth_cb(config_str)``: Callback for retrieving OAuth Bearer token.
463+
Function argument ``config_str`` is a str from config: ``sasl.oauthbearer.config``.
464+
Return value of this callback is expected to be ``(token_str, expiry_time)`` tuple
465+
where ``expiry_time`` is the time in seconds since the epoch as a floating point number.
466+
This callback is useful only when ``sasl.mechanisms=OAUTHBEARER`` is set and
467+
is served to get the initial token before a successful broker connection can be made.
468+
The callback can be triggered by calling ``client.poll()`` or ``producer.flush()``.
469+
462470
* ``on_delivery(kafka.KafkaError, kafka.Message)`` (**Producer**): value is a Python function reference
463471
that is called once for each produced message to indicate the final
464472
delivery result (success or failure).

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ The scripts in this directory provide code examples using Confluent's Python cli
1414
* [protobuf_consumer.py](protobuf_consumer.py): DeserializingConsumer with ProtobufDeserializer
1515
* [sasl_producer.py](sasl_producer.py): SerializingProducer with SASL Authentication
1616
* [list_offsets.py](list_offsets.py): List committed offsets and consumer lag for group and topics
17+
* [oauth_producer.py](oauth_producer.py): SerializingProducer with OAuth Authentication (client credentials)
1718

1819
Additional examples for [Confluent Cloud](https://www.confluent.io/confluent-cloud/):
1920

examples/oauth_producer.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2020 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
#
20+
# This uses OAuth client credentials grant:
21+
# https://www.oauth.com/oauth2-servers/access-tokens/client-credentials/
22+
# where client_id and client_secret are passed as HTTP Authorization header
23+
#
24+
25+
import logging
26+
import functools
27+
import argparse
28+
import time
29+
from confluent_kafka import SerializingProducer
30+
from confluent_kafka.serialization import StringSerializer
31+
import requests
32+
33+
34+
def _get_token(args, config):
35+
"""Note here value of config comes from sasl.oauthbearer.config below.
36+
It is not used in this example but you can put arbitrary values to
37+
configure how you can get the token (e.g. which token URL to use)
38+
"""
39+
payload = {
40+
'grant_type': 'client_credentials',
41+
'scope': ' '.join(args.scopes)
42+
}
43+
resp = requests.post(args.token_url,
44+
auth=(args.client_id, args.client_secret),
45+
data=payload)
46+
token = resp.json()
47+
return token['access_token'], time.time() + float(token['expires_in'])
48+
49+
50+
def producer_config(args):
51+
logger = logging.getLogger(__name__)
52+
return {
53+
'bootstrap.servers': args.bootstrap_servers,
54+
'key.serializer': StringSerializer('utf_8'),
55+
'value.serializer': StringSerializer('utf_8'),
56+
'security.protocol': 'sasl_plaintext',
57+
'sasl.mechanisms': 'OAUTHBEARER',
58+
# sasl.oauthbearer.config can be used to pass argument to your oauth_cb
59+
# It is not used in this example since we are passing all the arguments
60+
# from command line
61+
# 'sasl.oauthbearer.config': 'not-used',
62+
'oauth_cb': functools.partial(_get_token, args),
63+
'logger': logger,
64+
}
65+
66+
67+
def delivery_report(err, msg):
68+
"""
69+
Reports the failure or success of a message delivery.
70+
71+
Args:
72+
err (KafkaError): The error that occurred on None on success.
73+
74+
msg (Message): The message that was produced or failed.
75+
76+
Note:
77+
In the delivery report callback the Message.key() and Message.value()
78+
will be the binary format as encoded by any configured Serializers and
79+
not the same object that was passed to produce().
80+
If you wish to pass the original object(s) for key and value to delivery
81+
report callback we recommend a bound callback or lambda where you pass
82+
the objects along.
83+
84+
"""
85+
if err is not None:
86+
print('Delivery failed for User record {}: {}'.format(msg.key(), err))
87+
return
88+
print('User record {} successfully produced to {} [{}] at offset {}'.format(
89+
msg.key(), msg.topic(), msg.partition(), msg.offset()))
90+
91+
92+
def main(args):
93+
topic = args.topic
94+
delimiter = args.delimiter
95+
96+
producer_conf = producer_config(args)
97+
98+
producer = SerializingProducer(producer_conf)
99+
100+
print('Producing records to topic {}. ^C to exit.'.format(topic))
101+
while True:
102+
# Serve on_delivery callbacks from previous calls to produce()
103+
producer.poll(0.0)
104+
try:
105+
msg_data = input(">")
106+
msg = msg_data.split(delimiter)
107+
if len(msg) == 2:
108+
producer.produce(topic=topic, key=msg[0], value=msg[1],
109+
on_delivery=delivery_report)
110+
else:
111+
producer.produce(topic=topic, value=msg[0],
112+
on_delivery=delivery_report)
113+
except KeyboardInterrupt:
114+
break
115+
116+
print('\nFlushing {} records...'.format(len(producer)))
117+
producer.flush()
118+
119+
120+
if __name__ == '__main__':
121+
parser = argparse.ArgumentParser(description="SerializingProducer OAUTH Example"
122+
" with client credentials grant")
123+
parser.add_argument('-b', dest="bootstrap_servers", required=True,
124+
help="Bootstrap broker(s) (host[:port])")
125+
parser.add_argument('-t', dest="topic", default="example_producer_oauth",
126+
help="Topic name")
127+
parser.add_argument('-d', dest="delimiter", default="|",
128+
help="Key-Value delimiter. Defaults to '|'"),
129+
parser.add_argument('--client', dest="client_id", required=True,
130+
help="Client ID for client credentials flow")
131+
parser.add_argument('--secret', dest="client_secret", required=True,
132+
help="Client secret for client credentials flow.")
133+
parser.add_argument('--token-url', dest="token_url", required=True,
134+
help="Token URL.")
135+
parser.add_argument('--scopes', dest="scopes", required=True, nargs='+',
136+
help="Scopes requested from OAuth server.")
137+
138+
main(parser.parse_args())

examples/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pyrsistent==0.16.1;python_version<"3.0"
77
pyrsistent;python_version>"3.0"
88
jsonschema
99
protobuf
10+
requests

src/confluent_kafka/src/confluent_kafka.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,50 @@ static void log_cb (const rd_kafka_t *rk, int level,
15211521
CallState_resume(cs);
15221522
}
15231523

1524+
static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
1525+
void *opaque) {
1526+
Handle *h = opaque;
1527+
PyObject *eo, *result;
1528+
CallState *cs;
1529+
const char *token;
1530+
double expiry;
1531+
char err_msg[2048];
1532+
rd_kafka_resp_err_t err_code;
1533+
1534+
cs = CallState_get(h);
1535+
1536+
eo = Py_BuildValue("s", oauthbearer_config);
1537+
result = PyObject_CallFunctionObjArgs(h->oauth_cb, eo, NULL);
1538+
Py_DECREF(eo);
1539+
1540+
if (!result) {
1541+
goto err;
1542+
}
1543+
if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) {
1544+
Py_DECREF(result);
1545+
PyErr_Format(PyExc_TypeError,
1546+
"expect returned value from oauth_cb "
1547+
"to be (token_str, expiry_time) tuple");
1548+
goto err;
1549+
}
1550+
err_code = rd_kafka_oauthbearer_set_token(h->rk, token,
1551+
(int64_t)(expiry * 1000),
1552+
"", NULL, 0, err_msg,
1553+
sizeof(err_msg));
1554+
Py_DECREF(result);
1555+
if (err_code) {
1556+
PyErr_Format(PyExc_ValueError, "%s", err_msg);
1557+
goto err;
1558+
}
1559+
goto done;
1560+
1561+
err:
1562+
CallState_crash(cs);
1563+
rd_kafka_yield(h->rk);
1564+
done:
1565+
CallState_resume(cs);
1566+
}
1567+
15241568
/****************************************************************************
15251569
*
15261570
*
@@ -1949,6 +1993,25 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
19491993
Py_XDECREF(ks8);
19501994
Py_DECREF(ks);
19511995
continue;
1996+
} else if (!strcmp(k, "oauth_cb")) {
1997+
if (!PyCallable_Check(vo)) {
1998+
PyErr_SetString(PyExc_TypeError,
1999+
"expected oauth_cb property "
2000+
"as a callable function");
2001+
goto inner_err;
2002+
}
2003+
if (h->oauth_cb) {
2004+
Py_DECREF(h->oauth_cb);
2005+
h->oauth_cb = NULL;
2006+
}
2007+
2008+
if (vo != Py_None) {
2009+
h->oauth_cb = vo;
2010+
Py_INCREF(h->oauth_cb);
2011+
}
2012+
Py_XDECREF(ks8);
2013+
Py_DECREF(ks);
2014+
continue;
19522015
}
19532016

19542017
/* Special handling for certain config keys. */
@@ -2019,6 +2082,9 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
20192082
rd_kafka_conf_set_log_cb(conf, log_cb);
20202083
}
20212084

2085+
if (h->oauth_cb)
2086+
rd_kafka_conf_set_oauthbearer_token_refresh_cb(conf, oauth_cb);
2087+
20222088
rd_kafka_conf_set_opaque(conf, h);
20232089

20242090
#ifdef WITH_PY_TSS

src/confluent_kafka/src/confluent_kafka.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ typedef struct {
236236
rd_kafka_type_t type; /* Producer or consumer */
237237

238238
PyObject *logger;
239+
PyObject *oauth_cb;
239240

240241
union {
241242
/**

tests/test_misc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,35 @@ def test_throttle_event_types():
130130
assert str(throttle_event) == "broker/0 throttled for 10000 ms"
131131

132132

133+
# global variable for oauth_cb call back function
134+
seen_oauth_cb = False
135+
136+
137+
def test_oauth_cb():
138+
""" Tests oauth_cb. """
139+
140+
def oauth_cb(oauth_config):
141+
global seen_oauth_cb
142+
seen_oauth_cb = True
143+
assert oauth_config == 'oauth_cb'
144+
return 'token', time.time() + 300.0
145+
146+
conf = {'group.id': 'test',
147+
'security.protocol': 'sasl_plaintext',
148+
'sasl.mechanisms': 'OAUTHBEARER',
149+
'socket.timeout.ms': '100',
150+
'session.timeout.ms': 1000, # Avoid close() blocking too long
151+
'sasl.oauthbearer.config': 'oauth_cb',
152+
'oauth_cb': oauth_cb
153+
}
154+
155+
kc = confluent_kafka.Consumer(**conf)
156+
157+
while not seen_oauth_cb:
158+
kc.poll(timeout=1)
159+
kc.close()
160+
161+
133162
def skip_interceptors():
134163
# Run interceptor test if monitoring-interceptor is found
135164
for path in ["/usr/lib", "/usr/local/lib", "staging/libs", "."]:

0 commit comments

Comments
 (0)