|
7 | 7 | import os |
8 | 8 | import re |
9 | 9 | import time |
| 10 | +import struct |
10 | 11 |
|
11 | 12 | from django.core.exceptions import ImproperlyConfigured |
12 | 13 |
|
@@ -53,7 +54,22 @@ def encode_connection_string(fields): |
53 | 54 | '%s=%s' % (k, encode_value(v)) |
54 | 55 | for k, v in fields.items() |
55 | 56 | ) |
| 57 | +def prepare_token_for_odbc(token): |
| 58 | + """ |
| 59 | + Will prepare token for passing it to the odbc driver, as it expects |
| 60 | + bytes and not a string |
| 61 | + :param token: |
| 62 | + :return: packed binary byte representation of token string |
| 63 | + """ |
| 64 | + if not isinstance(token, str): |
| 65 | + raise TypeError("Invalid token format provided.") |
56 | 66 |
|
| 67 | + tokenstr = token.encode() |
| 68 | + exptoken = b"" |
| 69 | + for i in tokenstr: |
| 70 | + exptoken += bytes({i}) |
| 71 | + exptoken += bytes(1) |
| 72 | + return struct.pack("=i", len(exptoken)) + exptoken |
57 | 73 |
|
58 | 74 | def encode_value(v): |
59 | 75 | """If the value contains a semicolon, or starts with a left curly brace, |
@@ -294,7 +310,7 @@ def get_new_connection(self, conn_params): |
294 | 310 | cstr_parts['UID'] = user |
295 | 311 | if 'Authentication=ActiveDirectoryInteractive' not in options_extra_params: |
296 | 312 | cstr_parts['PWD'] = password |
297 | | - else: |
| 313 | + elif 'TOKEN' not in conn_params: |
298 | 314 | if ms_drivers.match(driver) and 'Authentication=ActiveDirectoryMsi' not in options_extra_params: |
299 | 315 | cstr_parts['Trusted_Connection'] = trusted_connection |
300 | 316 | else: |
@@ -324,11 +340,17 @@ def get_new_connection(self, conn_params): |
324 | 340 | conn = None |
325 | 341 | retry_count = 0 |
326 | 342 | need_to_retry = False |
| 343 | + args = { |
| 344 | + 'unicode_results': unicode_results, |
| 345 | + 'timeout': timeout, |
| 346 | + } |
| 347 | + if 'TOKEN' in conn_params: |
| 348 | + args['attrs_before'] = { |
| 349 | + 1256: prepare_token_for_odbc(conn_params['TOKEN']) |
| 350 | + } |
327 | 351 | while conn is None: |
328 | 352 | try: |
329 | | - conn = Database.connect(connstr, |
330 | | - unicode_results=unicode_results, |
331 | | - timeout=timeout) |
| 353 | + conn = Database.connect(connstr, **args) |
332 | 354 | except Exception as e: |
333 | 355 | for error_number in self._transient_error_numbers: |
334 | 356 | if error_number in e.args[1]: |
|
0 commit comments