diff --git a/cachecontrol/adapter.py b/cachecontrol/adapter.py index e739a480..9a326349 100644 --- a/cachecontrol/adapter.py +++ b/cachecontrol/adapter.py @@ -5,6 +5,7 @@ import functools import types +import weakref import zlib from typing import TYPE_CHECKING, Any, Collection, Mapping @@ -128,19 +129,25 @@ def build_response( # type: ignore[override] response._fp = CallbackFileWrapper( # type: ignore[assignment] response._fp, # type: ignore[arg-type] functools.partial( - self.controller.cache_response, request, response + self.controller.cache_response, request, weakref.ref(response) ), ) if response.chunked: - super_update_chunk_length = response._update_chunk_length + super_update_chunk_length = response.__class__._update_chunk_length - def _update_chunk_length(self: HTTPResponse) -> None: - super_update_chunk_length() + def _update_chunk_length( + weak_self: weakref.ReferenceType[HTTPResponse], + ) -> None: + self = weak_self() + if self is None: + return + + super_update_chunk_length(self) if self.chunk_left == 0: self._fp._close() # type: ignore[union-attr] - response._update_chunk_length = types.MethodType( # type: ignore[method-assign] - _update_chunk_length, response + response._update_chunk_length = functools.partial( # type: ignore[method-assign] + _update_chunk_length, weakref.ref(response) ) resp: Response = super().build_response(request, response) diff --git a/cachecontrol/controller.py b/cachecontrol/controller.py index f826aec0..4e251c8f 100644 --- a/cachecontrol/controller.py +++ b/cachecontrol/controller.py @@ -12,6 +12,7 @@ import logging import re import time +import weakref from email.utils import parsedate_tz from typing import TYPE_CHECKING, Collection, Mapping @@ -323,7 +324,7 @@ def _cache_set( def cache_response( self, request: PreparedRequest, - response: HTTPResponse, + response_or_ref: HTTPResponse | weakref.ReferenceType[HTTPResponse], body: bytes | None = None, status_codes: Collection[int] | None = None, ) -> None: @@ -332,6 +333,16 @@ def cache_response( This assumes a requests Response object. """ + if isinstance(response_or_ref, weakref.ReferenceType): + response = response_or_ref() + if response is None: + # The weakref can be None only in case the user used streamed request + # and did not consume or close it, and holds no reference to requests.Response. + # In such case, we don't want to cache the response. + return + else: + response = response_or_ref + # From httplib2: Don't cache 206's since we aren't going to # handle byte range requests cacheable_status_codes = status_codes or self.cacheable_status_codes diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 2614e098..7fd0c59e 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +import gc +import weakref from unittest import mock -import pytest +import pytest from requests import Session + from cachecontrol.adapter import CacheControlAdapter from cachecontrol.cache import DictCache from cachecontrol.wrapper import CacheControl @@ -65,3 +68,22 @@ def test_close(self): sess.close() assert cache.close.called + + def test_do_not_leak_response(self, url, sess): + resp = sess.get(url + "stream", stream=True) + resp.raise_for_status() + r1_weak = weakref.ref(resp.raw) + + # This is a mis-use of requests, becase we should either consume + # the body, or call .close(). + # But requests without cachecontrol handle this anyway, because + # urllib3.response.HTTPResponse has a __del__ finalizer on it that closes it + # once there are no more references to it. + # We should not break this. + + resp = None + # Below this point, it should be closed because there are no more references + # to the session response. + + r1 = r1_weak() + assert r1 is None or r1.closed