|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from base64 import b64encode |
| 4 | +from collections import defaultdict |
4 | 5 | from collections.abc import Mapping, MutableMapping, Sequence |
| 6 | +from io import BytesIO |
5 | 7 | from typing import Any, Union, overload |
6 | 8 |
|
7 | 9 | from w3lib.util import to_bytes, to_unicode |
@@ -44,23 +46,22 @@ def headers_raw_to_dict(headers_raw: bytes | None) -> HeadersDictOutput | None: |
44 | 46 |
|
45 | 47 | if headers_raw is None: |
46 | 48 | return None |
47 | | - headers = headers_raw.splitlines() |
48 | | - headers_tuples = [header.split(b":", 1) for header in headers] |
49 | 49 |
|
50 | | - result_dict: HeadersDictOutput = {} |
51 | | - for header_item in headers_tuples: |
52 | | - if len(header_item) != 2: |
53 | | - continue |
| 50 | + if not headers_raw: |
| 51 | + return {} |
| 52 | + |
| 53 | + headers = iter(BytesIO(headers_raw).readline, b"") |
| 54 | + result_dict = defaultdict(list) |
54 | 55 |
|
55 | | - item_key = header_item[0].strip() |
56 | | - item_value = header_item[1].strip() |
| 56 | + for header in headers: |
| 57 | + parts = header.split(b":", 1) |
| 58 | + if len(parts) != 2: |
| 59 | + continue |
57 | 60 |
|
58 | | - if item_key in result_dict: |
59 | | - result_dict[item_key].append(item_value) |
60 | | - else: |
61 | | - result_dict[item_key] = [item_value] |
| 61 | + key, value = map(bytes.strip, parts) |
| 62 | + result_dict[key].append(value) |
62 | 63 |
|
63 | | - return result_dict |
| 64 | + return dict(result_dict) |
64 | 65 |
|
65 | 66 |
|
66 | 67 | @overload |
@@ -93,13 +94,25 @@ def headers_dict_to_raw(headers_dict: HeadersDictInput | None) -> bytes | None: |
93 | 94 |
|
94 | 95 | if headers_dict is None: |
95 | 96 | return None |
96 | | - raw_lines = [] |
| 97 | + |
| 98 | + if not headers_dict: |
| 99 | + return b"" |
| 100 | + |
| 101 | + parts = bytearray() |
| 102 | + |
97 | 103 | for key, value in headers_dict.items(): |
98 | 104 | if isinstance(value, bytes): |
99 | | - raw_lines.append(b": ".join([key, value])) |
| 105 | + if parts: |
| 106 | + parts.extend(b"\r\n") |
| 107 | + parts.extend(key + b": " + value) |
| 108 | + |
100 | 109 | elif isinstance(value, (list, tuple)): |
101 | | - raw_lines.extend(b": ".join([key, v]) for v in value) |
102 | | - return b"\r\n".join(raw_lines) |
| 110 | + for v in value: |
| 111 | + if parts: |
| 112 | + parts.extend(b"\r\n") |
| 113 | + parts.extend(key + b": " + v) |
| 114 | + |
| 115 | + return bytes(parts) |
103 | 116 |
|
104 | 117 |
|
105 | 118 | def basic_auth_header( |
|
0 commit comments